Payment checkout test, rate limit resetting on tier change; failing

pull/600/head
binwiederhier 2023-01-25 22:26:04 -05:00
parent 236254d907
commit 593e0748a8
8 changed files with 257 additions and 42 deletions

View File

@ -46,8 +46,8 @@ const (
DefaultVisitorRequestLimitReplenish = 5 * time.Second DefaultVisitorRequestLimitReplenish = 5 * time.Second
DefaultVisitorEmailLimitBurst = 16 DefaultVisitorEmailLimitBurst = 16
DefaultVisitorEmailLimitReplenish = time.Hour DefaultVisitorEmailLimitReplenish = time.Hour
DefaultVisitorAccountCreateLimitBurst = 3 DefaultVisitorAccountCreationLimitBurst = 3
DefaultVisitorAccountCreateLimitReplenish = 24 * time.Hour DefaultVisitorAccountCreationLimitReplenish = 24 * time.Hour
DefaultVisitorAttachmentTotalSizeLimit = 100 * 1024 * 1024 // 100 MB DefaultVisitorAttachmentTotalSizeLimit = 100 * 1024 * 1024 // 100 MB
DefaultVisitorAttachmentDailyBandwidthLimit = 500 * 1024 * 1024 // 500 MB DefaultVisitorAttachmentDailyBandwidthLimit = 500 * 1024 * 1024 // 500 MB
) )
@ -107,8 +107,8 @@ type Config struct {
VisitorRequestExemptIPAddrs []netip.Prefix VisitorRequestExemptIPAddrs []netip.Prefix
VisitorEmailLimitBurst int VisitorEmailLimitBurst int
VisitorEmailLimitReplenish time.Duration VisitorEmailLimitReplenish time.Duration
VisitorAccountCreateLimitBurst int VisitorAccountCreationLimitBurst int
VisitorAccountCreateLimitReplenish time.Duration VisitorAccountCreationLimitReplenish time.Duration
VisitorStatsResetTime time.Time // Time of the day at which to reset visitor stats VisitorStatsResetTime time.Time // Time of the day at which to reset visitor stats
BehindProxy bool BehindProxy bool
StripeSecretKey string StripeSecretKey string
@ -173,8 +173,8 @@ func NewConfig() *Config {
VisitorRequestExemptIPAddrs: make([]netip.Prefix, 0), VisitorRequestExemptIPAddrs: make([]netip.Prefix, 0),
VisitorEmailLimitBurst: DefaultVisitorEmailLimitBurst, VisitorEmailLimitBurst: DefaultVisitorEmailLimitBurst,
VisitorEmailLimitReplenish: DefaultVisitorEmailLimitReplenish, VisitorEmailLimitReplenish: DefaultVisitorEmailLimitReplenish,
VisitorAccountCreateLimitBurst: DefaultVisitorAccountCreateLimitBurst, VisitorAccountCreationLimitBurst: DefaultVisitorAccountCreationLimitBurst,
VisitorAccountCreateLimitReplenish: DefaultVisitorAccountCreateLimitReplenish, VisitorAccountCreationLimitReplenish: DefaultVisitorAccountCreationLimitReplenish,
VisitorStatsResetTime: DefaultVisitorStatsResetTime, VisitorStatsResetTime: DefaultVisitorStatsResetTime,
BehindProxy: false, BehindProxy: false,
StripeSecretKey: "", StripeSecretKey: "",

View File

@ -40,6 +40,8 @@ TODO
- HIGH Rate limiting: dailyLimitToRate is wrong? + TESTS - HIGH Rate limiting: dailyLimitToRate is wrong? + TESTS
- HIGH Rate limiting: Sensitive endpoints (account/login/change-password/...) - HIGH Rate limiting: Sensitive endpoints (account/login/change-password/...)
- HIGH Rate limiting: Delete visitor when tier is changed to refresh rate limiters
- HIGH Rate limiting: When ResetStats() is run, reset messagesLimiter (and others)?
- MEDIUM: Races with v.user (see publishSyncEventAsync test) - MEDIUM: Races with v.user (see publishSyncEventAsync test)
- MEDIUM: Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben) - MEDIUM: Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben)
- MEDIUM: Reservation (UI): Ask for confirmation when removing reservation (deadcade) - MEDIUM: Reservation (UI): Ask for confirmation when removing reservation (deadcade)
@ -50,8 +52,6 @@ TODO
Limits & rate limiting: Limits & rate limiting:
users without tier: should the stats be persisted? are they meaningful? -> test that the visitor is based on the IP address! users without tier: should the stats be persisted? are they meaningful? -> test that the visitor is based on the IP address!
when ResetStats() is run, reset messagesLimiter (and others)?
Delete visitor when tier is changed to refresh rate limiters
Make sure account endpoints make sense for admins Make sure account endpoints make sense for admins
@ -1602,9 +1602,7 @@ func (s *Server) visitor(r *http.Request) (v *visitor, err error) {
} else { } else {
v = s.visitorFromIP(ip) v = s.visitorFromIP(ip)
} }
v.mu.Lock() v.SetUser(u) // Update visitor user with latest from database!
v.user = u
v.mu.Unlock()
return v, err // Always return visitor, even when error occurs! return v, err // Always return visitor, even when error occurs!
} }

View File

@ -31,7 +31,7 @@ func (s *Server) handleAccountCreate(w http.ResponseWriter, r *http.Request, v *
if existingUser, _ := s.userManager.User(newAccount.Username); existingUser != nil { if existingUser, _ := s.userManager.User(newAccount.Username); existingUser != nil {
return errHTTPConflictUserExists return errHTTPConflictUserExists
} }
if v.accountLimiter != nil && !v.accountLimiter.Allow() { if err := v.AccountCreationAllowed(); err != nil {
return errHTTPTooManyRequestsLimitAccountCreation return errHTTPTooManyRequestsLimitAccountCreation
} }
if err := s.userManager.AddUser(newAccount.Username, newAccount.Password, user.RoleUser); err != nil { // TODO this should return a User if err := s.userManager.AddUser(newAccount.Username, newAccount.Password, user.RoleUser); err != nil { // TODO this should return a User

View File

@ -6,6 +6,7 @@ import (
"heckel.io/ntfy/user" "heckel.io/ntfy/user"
"heckel.io/ntfy/util" "heckel.io/ntfy/util"
"io" "io"
"strings"
"testing" "testing"
"time" "time"
) )
@ -91,6 +92,20 @@ func TestAccount_Signup_Disabled(t *testing.T) {
require.Equal(t, 40022, toHTTPError(t, rr.Body.String()).Code) require.Equal(t, 40022, toHTTPError(t, rr.Body.String()).Code)
} }
func TestAccount_Signup_Rate_Limit(t *testing.T) {
conf := newTestConfigWithAuthFile(t)
conf.EnableSignup = true
s := newTestServer(t, conf)
for i := 0; i < 3; i++ {
rr := request(t, s, "POST", "/v1/account", fmt.Sprintf(`{"username":"phil%d", "password":"mypass"}`, i), nil)
require.Equal(t, 200, rr.Code, "failed on iteration %d", i)
}
rr := request(t, s, "POST", "/v1/account", `{"username":"notallowed", "password":"mypass"}`, nil)
require.Equal(t, 429, rr.Code)
require.Equal(t, 42906, toHTTPError(t, rr.Body.String()).Code)
}
func TestAccount_Get_Anonymous(t *testing.T) { func TestAccount_Get_Anonymous(t *testing.T) {
conf := newTestConfigWithAuthFile(t) conf := newTestConfigWithAuthFile(t)
conf.VisitorRequestLimitReplenish = 86 * time.Second conf.VisitorRequestLimitReplenish = 86 * time.Second
@ -567,3 +582,60 @@ func TestAccount_Reservation_Add_Kills_Other_Subscribers(t *testing.T) {
s.topics["mytopic"].CancelSubscribers("<invalid>") s.topics["mytopic"].CancelSubscribers("<invalid>")
<-userCh <-userCh
} }
func TestAccount_Tier_Create(t *testing.T) {
conf := newTestConfigWithAuthFile(t)
s := newTestServer(t, conf)
// Create tier and user
require.Nil(t, s.userManager.CreateTier(&user.Tier{
Code: "pro",
Name: "Pro",
MessagesLimit: 123,
MessagesExpiryDuration: 86400 * time.Second,
EmailsLimit: 32,
ReservationsLimit: 2,
AttachmentFileSizeLimit: 1231231,
AttachmentTotalSizeLimit: 123123,
AttachmentExpiryDuration: 10800 * time.Second,
AttachmentBandwidthLimit: 21474836480,
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
ti, err := s.userManager.Tier("pro")
require.Nil(t, err)
u, err := s.userManager.User("phil")
require.Nil(t, err)
// These are populated by different SQL queries
require.Equal(t, ti, u.Tier)
// Fields
require.True(t, strings.HasPrefix(ti.ID, "ti_"))
require.Equal(t, "pro", ti.Code)
require.Equal(t, "Pro", ti.Name)
require.Equal(t, int64(123), ti.MessagesLimit)
require.Equal(t, 86400*time.Second, ti.MessagesExpiryDuration)
require.Equal(t, int64(32), ti.EmailsLimit)
require.Equal(t, int64(2), ti.ReservationsLimit)
require.Equal(t, int64(1231231), ti.AttachmentFileSizeLimit)
require.Equal(t, int64(123123), ti.AttachmentTotalSizeLimit)
require.Equal(t, 10800*time.Second, ti.AttachmentExpiryDuration)
require.Equal(t, int64(21474836480), ti.AttachmentBandwidthLimit)
}
func TestAccount_Tier_Create_With_ID(t *testing.T) {
conf := newTestConfigWithAuthFile(t)
s := newTestServer(t, conf)
require.Nil(t, s.userManager.CreateTier(&user.Tier{
ID: "ti_123",
Code: "pro",
}))
ti, err := s.userManager.Tier("pro")
require.Nil(t, err)
require.Equal(t, "ti_123", ti.ID)
}

View File

@ -133,6 +133,7 @@ func TestPayments_SubscriptionCreate_NotAStripeCustomer_Success(t *testing.T) {
// Create tier and user // Create tier and user
require.Nil(t, s.userManager.CreateTier(&user.Tier{ require.Nil(t, s.userManager.CreateTier(&user.Tier{
ID: "ti_123",
Code: "pro", Code: "pro",
StripePriceID: "price_123", StripePriceID: "price_123",
})) }))
@ -168,6 +169,7 @@ func TestPayments_SubscriptionCreate_StripeCustomer_Success(t *testing.T) {
// Create tier and user // Create tier and user
require.Nil(t, s.userManager.CreateTier(&user.Tier{ require.Nil(t, s.userManager.CreateTier(&user.Tier{
ID: "ti_123",
Code: "pro", Code: "pro",
StripePriceID: "price_123", StripePriceID: "price_123",
})) }))
@ -209,6 +211,7 @@ func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) {
// Create tier and user // Create tier and user
require.Nil(t, s.userManager.CreateTier(&user.Tier{ require.Nil(t, s.userManager.CreateTier(&user.Tier{
ID: "ti_123",
Code: "pro", Code: "pro",
StripePriceID: "price_123", StripePriceID: "price_123",
})) }))
@ -235,6 +238,106 @@ func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) {
require.Equal(t, 401, rr.Code) require.Equal(t, 401, rr.Code)
} }
func TestPayments_Checkout_Success_And_Increase_Ratelimits_Reset_Visitor(t *testing.T) {
// This tests a successful checkout flow (not a paying customer -> paying customer),
// and also tests that during the upgrade we are RESETTING THE RATE LIMITS of the existing user.
stripeMock := &testStripeAPI{}
defer stripeMock.AssertExpectations(t)
c := newTestConfigWithAuthFile(t)
c.StripeSecretKey = "secret key"
c.StripeWebhookKey = "webhook key"
c.VisitorRequestLimitBurst = 10
c.VisitorRequestLimitReplenish = time.Hour
s := newTestServer(t, c)
s.stripe = stripeMock
// Create a user with a Stripe subscription and 3 reservations
require.Nil(t, s.userManager.CreateTier(&user.Tier{
ID: "ti_123",
Code: "starter",
StripePriceID: "price_1234",
ReservationsLimit: 1,
MessagesLimit: 100,
MessagesExpiryDuration: time.Hour,
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) // No tier
u, err := s.userManager.User("phil")
require.Nil(t, err)
// Define how the mock should react
stripeMock.
On("GetSession", "SOMETOKEN").
Return(&stripe.CheckoutSession{
ClientReferenceID: u.ID, // ntfy user ID
Customer: &stripe.Customer{
ID: "acct_5555",
},
Subscription: &stripe.Subscription{
ID: "sub_1234",
},
}, nil)
stripeMock.
On("GetSubscription", "sub_1234").
Return(&stripe.Subscription{
ID: "sub_1234",
Status: stripe.SubscriptionStatusActive,
CurrentPeriodEnd: 123456789,
CancelAt: 0,
Items: &stripe.SubscriptionItemList{
Data: []*stripe.SubscriptionItem{
{
Price: &stripe.Price{ID: "price_1234"},
},
},
},
}, nil)
stripeMock.
On("UpdateCustomer", mock.Anything).
Return(&stripe.Customer{}, nil)
// Send messages until rate limit of free tier is hit
for i := 0; i < 10; i++ {
rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
}
rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 429, rr.Code)
// Simulate Stripe success return URL call (no user context)
rr = request(t, s, "GET", "/v1/account/billing/subscription/success/SOMETOKEN", "", nil)
require.Equal(t, 303, rr.Code)
// Verify that database columns were updated
u, err = s.userManager.User("phil")
require.Nil(t, err)
require.Equal(t, "starter", u.Tier.Code) // Not "pro"
require.Equal(t, "acct_5555", u.Billing.StripeCustomerID)
require.Equal(t, "sub_1234", u.Billing.StripeSubscriptionID)
require.Equal(t, stripe.SubscriptionStatusActive, u.Billing.StripeSubscriptionStatus)
require.Equal(t, int64(123456789), u.Billing.StripeSubscriptionPaidUntil.Unix())
require.Equal(t, int64(0), u.Billing.StripeSubscriptionCancelAt.Unix())
// FIXME FIXME This test is broken, because the rate limit logic is unclear!
// Now for the fun part: Verify that new rate limits are immediately applied
for i := 0; i < 100; i++ {
rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code, "failed on iteration %d", i)
}
rr = request(t, s, "PUT", "/mytopic", "some message", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 429, rr.Code)
}
func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(t *testing.T) { func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(t *testing.T) {
// This tests incoming webhooks from Stripe to update a subscription: // This tests incoming webhooks from Stripe to update a subscription:
// - All Stripe columns are updated in the user table // - All Stripe columns are updated in the user table
@ -257,6 +360,7 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
// Create a user with a Stripe subscription and 3 reservations // Create a user with a Stripe subscription and 3 reservations
require.Nil(t, s.userManager.CreateTier(&user.Tier{ require.Nil(t, s.userManager.CreateTier(&user.Tier{
ID: "ti_1",
Code: "starter", Code: "starter",
StripePriceID: "price_1234", // ! StripePriceID: "price_1234", // !
ReservationsLimit: 1, // ! ReservationsLimit: 1, // !
@ -268,6 +372,7 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
AttachmentBandwidthLimit: 1000000, AttachmentBandwidthLimit: 1000000,
})) }))
require.Nil(t, s.userManager.CreateTier(&user.Tier{ require.Nil(t, s.userManager.CreateTier(&user.Tier{
ID: "ti_2",
Code: "pro", Code: "pro",
StripePriceID: "price_1111", // ! StripePriceID: "price_1111", // !
ReservationsLimit: 3, // ! ReservationsLimit: 3, // !

View File

@ -3,6 +3,7 @@ package server
import ( import (
"errors" "errors"
"fmt" "fmt"
"heckel.io/ntfy/log"
"heckel.io/ntfy/user" "heckel.io/ntfy/user"
"net/netip" "net/netip"
"sync" "sync"
@ -41,7 +42,7 @@ type visitor struct {
emailsLimiter *rate.Limiter // Rate limiter for emails emailsLimiter *rate.Limiter // Rate limiter for emails
subscriptionLimiter util.Limiter // Fixed limiter for active subscriptions (ongoing connections) subscriptionLimiter util.Limiter // Fixed limiter for active subscriptions (ongoing connections)
bandwidthLimiter util.Limiter // Limiter for attachment bandwidth downloads bandwidthLimiter util.Limiter // Limiter for attachment bandwidth downloads
accountLimiter *rate.Limiter // Rate limiter for account creation accountLimiter *rate.Limiter // Rate limiter for account creation, may be nil
firebase time.Time // Next allowed Firebase message firebase time.Time // Next allowed Firebase message
seen time.Time // Last seen time of this visitor (needed for removal of stale visitors) seen time.Time // Last seen time of this visitor (needed for removal of stale visitors)
mu sync.Mutex mu sync.Mutex
@ -85,26 +86,12 @@ const (
) )
func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Manager, ip netip.Addr, user *user.User) *visitor { func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Manager, ip netip.Addr, user *user.User) *visitor {
var messagesLimiter, attachmentBandwidthLimiter util.Limiter
var requestLimiter, emailsLimiter, accountLimiter *rate.Limiter
var messages, emails int64 var messages, emails int64
if user != nil { if user != nil {
messages = user.Stats.Messages messages = user.Stats.Messages
emails = user.Stats.Emails emails = user.Stats.Emails
} else {
accountLimiter = rate.NewLimiter(rate.Every(conf.VisitorAccountCreateLimitReplenish), conf.VisitorAccountCreateLimitBurst)
} }
if user != nil && user.Tier != nil { v := &visitor{
requestLimiter = rate.NewLimiter(dailyLimitToRate(user.Tier.MessagesLimit), conf.VisitorRequestLimitBurst)
messagesLimiter = util.NewFixedLimiter(user.Tier.MessagesLimit)
emailsLimiter = rate.NewLimiter(dailyLimitToRate(user.Tier.EmailsLimit), conf.VisitorEmailLimitBurst)
attachmentBandwidthLimiter = util.NewBytesLimiter(int(user.Tier.AttachmentBandwidthLimit), 24*time.Hour)
} else {
requestLimiter = rate.NewLimiter(rate.Every(conf.VisitorRequestLimitReplenish), conf.VisitorRequestLimitBurst)
emailsLimiter = rate.NewLimiter(rate.Every(conf.VisitorEmailLimitReplenish), conf.VisitorEmailLimitBurst)
attachmentBandwidthLimiter = util.NewBytesLimiter(int(conf.VisitorAttachmentDailyBandwidthLimit), 24*time.Hour)
}
return &visitor{
config: conf, config: conf,
messageCache: messageCache, messageCache: messageCache,
userManager: userManager, // May be nil userManager: userManager, // May be nil
@ -112,20 +99,26 @@ func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Mana
user: user, user: user,
messages: messages, messages: messages,
emails: emails, emails: emails,
requestLimiter: requestLimiter,
messagesLimiter: messagesLimiter, // May be nil
emailsLimiter: emailsLimiter,
subscriptionLimiter: util.NewFixedLimiter(int64(conf.VisitorSubscriptionLimit)),
bandwidthLimiter: attachmentBandwidthLimiter,
accountLimiter: accountLimiter, // May be nil
firebase: time.Unix(0, 0), firebase: time.Unix(0, 0),
seen: time.Now(), seen: time.Now(),
subscriptionLimiter: util.NewFixedLimiter(int64(conf.VisitorSubscriptionLimit)),
requestLimiter: nil, // Set in resetLimiters
messagesLimiter: nil, // Set in resetLimiters, may be nil
emailsLimiter: nil, // Set in resetLimiters
bandwidthLimiter: nil, // Set in resetLimiters
accountLimiter: nil, // Set in resetLimiters, may be nil
} }
v.resetLimiters()
return v
} }
func (v *visitor) String() string { func (v *visitor) String() string {
v.mu.Lock() v.mu.Lock()
defer v.mu.Unlock() defer v.mu.Unlock()
return v.stringNoLock()
}
func (v *visitor) stringNoLock() string {
if v.user != nil && v.user.Billing.StripeCustomerID != "" { if v.user != nil && v.user.Billing.StripeCustomerID != "" {
return fmt.Sprintf("%s/%s/%s", v.ip.String(), v.user.ID, v.user.Billing.StripeCustomerID) return fmt.Sprintf("%s/%s/%s", v.ip.String(), v.user.ID, v.user.Billing.StripeCustomerID)
} else if v.user != nil { } else if v.user != nil {
@ -179,6 +172,13 @@ func (v *visitor) SubscriptionAllowed() error {
return nil return nil
} }
func (v *visitor) AccountCreationAllowed() error {
if v.accountLimiter != nil && !v.accountLimiter.Allow() {
return errVisitorLimitReached
}
return nil
}
func (v *visitor) RemoveSubscription() { func (v *visitor) RemoveSubscription() {
v.mu.Lock() v.mu.Lock()
defer v.mu.Unlock() defer v.mu.Unlock()
@ -235,7 +235,35 @@ func (v *visitor) ResetStats() {
func (v *visitor) SetUser(u *user.User) { func (v *visitor) SetUser(u *user.User) {
v.mu.Lock() v.mu.Lock()
defer v.mu.Unlock() defer v.mu.Unlock()
shouldResetLimiters := v.user.TierID() != u.TierID() // TierID works with nil receiver
v.user = u v.user = u
if shouldResetLimiters {
v.resetLimiters()
}
}
func (v *visitor) resetLimiters() {
log.Info("%s Resetting limiters for visitor", v.stringNoLock())
var messagesLimiter, bandwidthLimiter util.Limiter
var requestLimiter, emailsLimiter, accountLimiter *rate.Limiter
if v.user != nil && v.user.Tier != nil {
requestLimiter = rate.NewLimiter(dailyLimitToRate(v.user.Tier.MessagesLimit), v.config.VisitorRequestLimitBurst)
messagesLimiter = util.NewFixedLimiter(v.user.Tier.MessagesLimit)
emailsLimiter = rate.NewLimiter(dailyLimitToRate(v.user.Tier.EmailsLimit), v.config.VisitorEmailLimitBurst)
bandwidthLimiter = util.NewBytesLimiter(int(v.user.Tier.AttachmentBandwidthLimit), 24*time.Hour)
accountLimiter = nil // A logged-in user cannot create an account
} else {
requestLimiter = rate.NewLimiter(rate.Every(v.config.VisitorRequestLimitReplenish), v.config.VisitorRequestLimitBurst)
messagesLimiter = nil // Message limit is governed by the requestLimiter
emailsLimiter = rate.NewLimiter(rate.Every(v.config.VisitorEmailLimitReplenish), v.config.VisitorEmailLimitBurst)
bandwidthLimiter = util.NewBytesLimiter(int(v.config.VisitorAttachmentDailyBandwidthLimit), 24*time.Hour)
accountLimiter = rate.NewLimiter(rate.Every(v.config.VisitorAccountCreationLimitReplenish), v.config.VisitorAccountCreationLimitBurst)
}
v.requestLimiter = requestLimiter
v.messagesLimiter = messagesLimiter
v.emailsLimiter = emailsLimiter
v.bandwidthLimiter = bandwidthLimiter
v.accountLimiter = accountLimiter
} }
// MaybeUserID returns the user ID of the visitor (if any). If this is an anonymous visitor, // MaybeUserID returns the user ID of the visitor (if any). If this is an anonymous visitor,

View File

@ -110,26 +110,26 @@ const (
` `
selectUserByIDQuery = ` selectUserByIDQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
FROM user u FROM user u
LEFT JOIN tier t on t.id = u.tier_id LEFT JOIN tier t on t.id = u.tier_id
WHERE u.id = ? WHERE u.id = ?
` `
selectUserByNameQuery = ` selectUserByNameQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
FROM user u FROM user u
LEFT JOIN tier t on t.id = u.tier_id LEFT JOIN tier t on t.id = u.tier_id
WHERE user = ? WHERE user = ?
` `
selectUserByTokenQuery = ` selectUserByTokenQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
FROM user u FROM user u
JOIN user_token t on u.id = t.user_id JOIN user_token t on u.id = t.user_id
LEFT JOIN tier t on t.id = u.tier_id LEFT JOIN tier t on t.id = u.tier_id
WHERE t.token = ? AND t.expires >= ? WHERE t.token = ? AND t.expires >= ?
` `
selectUserByStripeCustomerIDQuery = ` selectUserByStripeCustomerIDQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
FROM user u FROM user u
LEFT JOIN tier t on t.id = u.tier_id LEFT JOIN tier t on t.id = u.tier_id
WHERE u.stripe_customer_id = ? WHERE u.stripe_customer_id = ?
@ -669,13 +669,13 @@ func (a *Manager) userByToken(token string) (*User, error) {
func (a *Manager) readUser(rows *sql.Rows) (*User, error) { func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
defer rows.Close() defer rows.Close()
var id, username, hash, role, prefs, syncTopic string var id, username, hash, role, prefs, syncTopic string
var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripePriceID, tierCode, tierName sql.NullString var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripePriceID, tierID, tierCode, tierName sql.NullString
var messages, emails int64 var messages, emails int64
var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt, deleted sql.NullInt64 var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt, deleted sql.NullInt64
if !rows.Next() { if !rows.Next() {
return nil, ErrUserNotFound return nil, ErrUserNotFound
} }
if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripePriceID); err != nil { if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierID, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripePriceID); err != nil {
return nil, err return nil, err
} else if err := rows.Err(); err != nil { } else if err := rows.Err(); err != nil {
return nil, err return nil, err
@ -706,6 +706,7 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
if tierCode.Valid { if tierCode.Valid {
// See readTier() when this is changed! // See readTier() when this is changed!
user.Tier = &Tier{ user.Tier = &Tier{
ID: tierID.String,
Code: tierCode.String, Code: tierCode.String,
Name: tierName.String, Name: tierName.String,
MessagesLimit: messagesLimit.Int64, MessagesLimit: messagesLimit.Int64,
@ -995,8 +996,10 @@ func (a *Manager) DefaultAccess() Permission {
// CreateTier creates a new tier in the database // CreateTier creates a new tier in the database
func (a *Manager) CreateTier(tier *Tier) error { func (a *Manager) CreateTier(tier *Tier) error {
tierID := util.RandomStringPrefix(tierIDPrefix, tierIDLength) if tier.ID == "" {
if _, err := a.db.Exec(insertTierQuery, tierID, tier.Code, tier.Name, tier.MessagesLimit, int64(tier.MessagesExpiryDuration.Seconds()), tier.EmailsLimit, tier.ReservationsLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, tier.StripePriceID); err != nil { tier.ID = util.RandomStringPrefix(tierIDPrefix, tierIDLength)
}
if _, err := a.db.Exec(insertTierQuery, tier.ID, tier.Code, tier.Name, tier.MessagesLimit, int64(tier.MessagesExpiryDuration.Seconds()), tier.EmailsLimit, tier.ReservationsLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, tier.StripePriceID); err != nil {
return err return err
} }
return nil return nil

View File

@ -23,6 +23,15 @@ type User struct {
Deleted bool Deleted bool
} }
// TierID returns the ID of the User.Tier, or an empty string if the user has no tier,
// or if the user itself is nil.
func (u *User) TierID() string {
if u == nil || u.Tier == nil {
return ""
}
return u.Tier.ID
}
// Auther is an interface for authentication and authorization // Auther is an interface for authentication and authorization
type Auther interface { type Auther interface {
// Authenticate checks username and password and returns a user if correct. The method // Authenticate checks username and password and returns a user if correct. The method