Add limiters and database changes

pull/717/head
binwiederhier 2023-05-07 11:59:15 -04:00
parent 113b7c8a08
commit f9e2d6ddcb
11 changed files with 173 additions and 32 deletions

View File

@ -83,6 +83,8 @@ var flagsServe = append(
altsrc.NewStringFlag(&cli.StringFlag{Name: "visitor-request-limit-exempt-hosts", Aliases: []string{"visitor_request_limit_exempt_hosts"}, EnvVars: []string{"NTFY_VISITOR_REQUEST_LIMIT_EXEMPT_HOSTS"}, Value: "", Usage: "hostnames and/or IP addresses of hosts that will be exempt from the visitor request limit"}), altsrc.NewStringFlag(&cli.StringFlag{Name: "visitor-request-limit-exempt-hosts", Aliases: []string{"visitor_request_limit_exempt_hosts"}, EnvVars: []string{"NTFY_VISITOR_REQUEST_LIMIT_EXEMPT_HOSTS"}, Value: "", Usage: "hostnames and/or IP addresses of hosts that will be exempt from the visitor request limit"}),
altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-message-daily-limit", Aliases: []string{"visitor_message_daily_limit"}, EnvVars: []string{"NTFY_VISITOR_MESSAGE_DAILY_LIMIT"}, Value: server.DefaultVisitorMessageDailyLimit, Usage: "max messages per visitor per day, derived from request limit if unset"}), altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-message-daily-limit", Aliases: []string{"visitor_message_daily_limit"}, EnvVars: []string{"NTFY_VISITOR_MESSAGE_DAILY_LIMIT"}, Value: server.DefaultVisitorMessageDailyLimit, Usage: "max messages per visitor per day, derived from request limit if unset"}),
altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-email-limit-burst", Aliases: []string{"visitor_email_limit_burst"}, EnvVars: []string{"NTFY_VISITOR_EMAIL_LIMIT_BURST"}, Value: server.DefaultVisitorEmailLimitBurst, Usage: "initial limit of e-mails per visitor"}), altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-email-limit-burst", Aliases: []string{"visitor_email_limit_burst"}, EnvVars: []string{"NTFY_VISITOR_EMAIL_LIMIT_BURST"}, Value: server.DefaultVisitorEmailLimitBurst, Usage: "initial limit of e-mails per visitor"}),
altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-sms-daily-limit", Aliases: []string{"visitor_sms_daily_limit"}, EnvVars: []string{"NTFY_VISITOR_SMS_DAILY_LIMIT"}, Value: server.DefaultVisitorSMSDailyLimit, Usage: "max number of SMS messages per visitor per day"}),
altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-call-daily-limit", Aliases: []string{"visitor_call_daily_limit"}, EnvVars: []string{"NTFY_VISITOR_CALL_DAILY_LIMIT"}, Value: server.DefaultVisitorCallDailyLimit, Usage: "max number of phone calls per visitor per day"}),
altsrc.NewDurationFlag(&cli.DurationFlag{Name: "visitor-email-limit-replenish", Aliases: []string{"visitor_email_limit_replenish"}, EnvVars: []string{"NTFY_VISITOR_EMAIL_LIMIT_REPLENISH"}, Value: server.DefaultVisitorEmailLimitReplenish, Usage: "interval at which burst limit is replenished (one per x)"}), altsrc.NewDurationFlag(&cli.DurationFlag{Name: "visitor-email-limit-replenish", Aliases: []string{"visitor_email_limit_replenish"}, EnvVars: []string{"NTFY_VISITOR_EMAIL_LIMIT_REPLENISH"}, Value: server.DefaultVisitorEmailLimitReplenish, Usage: "interval at which burst limit is replenished (one per x)"}),
altsrc.NewBoolFlag(&cli.BoolFlag{Name: "visitor-subscriber-rate-limiting", Aliases: []string{"visitor_subscriber_rate_limiting"}, EnvVars: []string{"NTFY_VISITOR_SUBSCRIBER_RATE_LIMITING"}, Value: false, Usage: "enables subscriber-based rate limiting"}), altsrc.NewBoolFlag(&cli.BoolFlag{Name: "visitor-subscriber-rate-limiting", Aliases: []string{"visitor_subscriber_rate_limiting"}, EnvVars: []string{"NTFY_VISITOR_SUBSCRIBER_RATE_LIMITING"}, Value: false, Usage: "enables subscriber-based rate limiting"}),
altsrc.NewBoolFlag(&cli.BoolFlag{Name: "behind-proxy", Aliases: []string{"behind_proxy", "P"}, EnvVars: []string{"NTFY_BEHIND_PROXY"}, Value: false, Usage: "if set, use X-Forwarded-For header to determine visitor IP address (for rate limiting)"}), altsrc.NewBoolFlag(&cli.BoolFlag{Name: "behind-proxy", Aliases: []string{"behind_proxy", "P"}, EnvVars: []string{"NTFY_BEHIND_PROXY"}, Value: false, Usage: "if set, use X-Forwarded-For header to determine visitor IP address (for rate limiting)"}),
@ -168,6 +170,8 @@ func execServe(c *cli.Context) error {
visitorMessageDailyLimit := c.Int("visitor-message-daily-limit") visitorMessageDailyLimit := c.Int("visitor-message-daily-limit")
visitorEmailLimitBurst := c.Int("visitor-email-limit-burst") visitorEmailLimitBurst := c.Int("visitor-email-limit-burst")
visitorEmailLimitReplenish := c.Duration("visitor-email-limit-replenish") visitorEmailLimitReplenish := c.Duration("visitor-email-limit-replenish")
visitorSMSDailyLimit := c.Int("visitor-sms-daily-limit")
visitorCallDailyLimit := c.Int("visitor-call-daily-limit")
behindProxy := c.Bool("behind-proxy") behindProxy := c.Bool("behind-proxy")
stripeSecretKey := c.String("stripe-secret-key") stripeSecretKey := c.String("stripe-secret-key")
stripeWebhookKey := c.String("stripe-webhook-key") stripeWebhookKey := c.String("stripe-webhook-key")
@ -329,6 +333,8 @@ func execServe(c *cli.Context) error {
conf.VisitorMessageDailyLimit = visitorMessageDailyLimit conf.VisitorMessageDailyLimit = visitorMessageDailyLimit
conf.VisitorEmailLimitBurst = visitorEmailLimitBurst conf.VisitorEmailLimitBurst = visitorEmailLimitBurst
conf.VisitorEmailLimitReplenish = visitorEmailLimitReplenish conf.VisitorEmailLimitReplenish = visitorEmailLimitReplenish
conf.VisitorSMSDailyLimit = visitorSMSDailyLimit
conf.VisitorCallDailyLimit = visitorCallDailyLimit
conf.VisitorSubscriberRateLimiting = visitorSubscriberRateLimiting conf.VisitorSubscriberRateLimiting = visitorSubscriberRateLimiting
conf.BehindProxy = behindProxy conf.BehindProxy = behindProxy
conf.StripeSecretKey = stripeSecretKey conf.StripeSecretKey = stripeSecretKey

View File

@ -18,6 +18,8 @@ const (
defaultMessageLimit = 5000 defaultMessageLimit = 5000
defaultMessageExpiryDuration = "12h" defaultMessageExpiryDuration = "12h"
defaultEmailLimit = 20 defaultEmailLimit = 20
defaultSMSLimit = 10
defaultCallLimit = 10
defaultReservationLimit = 3 defaultReservationLimit = 3
defaultAttachmentFileSizeLimit = "15M" defaultAttachmentFileSizeLimit = "15M"
defaultAttachmentTotalSizeLimit = "100M" defaultAttachmentTotalSizeLimit = "100M"
@ -48,6 +50,8 @@ var cmdTier = &cli.Command{
&cli.Int64Flag{Name: "message-limit", Value: defaultMessageLimit, Usage: "daily message limit"}, &cli.Int64Flag{Name: "message-limit", Value: defaultMessageLimit, Usage: "daily message limit"},
&cli.StringFlag{Name: "message-expiry-duration", Value: defaultMessageExpiryDuration, Usage: "duration after which messages are deleted"}, &cli.StringFlag{Name: "message-expiry-duration", Value: defaultMessageExpiryDuration, Usage: "duration after which messages are deleted"},
&cli.Int64Flag{Name: "email-limit", Value: defaultEmailLimit, Usage: "daily email limit"}, &cli.Int64Flag{Name: "email-limit", Value: defaultEmailLimit, Usage: "daily email limit"},
&cli.Int64Flag{Name: "sms-limit", Value: defaultSMSLimit, Usage: "daily SMS limit"},
&cli.Int64Flag{Name: "call-limit", Value: defaultCallLimit, Usage: "daily phone call limit"},
&cli.Int64Flag{Name: "reservation-limit", Value: defaultReservationLimit, Usage: "topic reservation limit"}, &cli.Int64Flag{Name: "reservation-limit", Value: defaultReservationLimit, Usage: "topic reservation limit"},
&cli.StringFlag{Name: "attachment-file-size-limit", Value: defaultAttachmentFileSizeLimit, Usage: "per-attachment file size limit"}, &cli.StringFlag{Name: "attachment-file-size-limit", Value: defaultAttachmentFileSizeLimit, Usage: "per-attachment file size limit"},
&cli.StringFlag{Name: "attachment-total-size-limit", Value: defaultAttachmentTotalSizeLimit, Usage: "total size limit of attachments for the user"}, &cli.StringFlag{Name: "attachment-total-size-limit", Value: defaultAttachmentTotalSizeLimit, Usage: "total size limit of attachments for the user"},
@ -91,6 +95,8 @@ Examples:
&cli.Int64Flag{Name: "message-limit", Usage: "daily message limit"}, &cli.Int64Flag{Name: "message-limit", Usage: "daily message limit"},
&cli.StringFlag{Name: "message-expiry-duration", Usage: "duration after which messages are deleted"}, &cli.StringFlag{Name: "message-expiry-duration", Usage: "duration after which messages are deleted"},
&cli.Int64Flag{Name: "email-limit", Usage: "daily email limit"}, &cli.Int64Flag{Name: "email-limit", Usage: "daily email limit"},
&cli.Int64Flag{Name: "sms-limit", Usage: "daily SMS limit"},
&cli.Int64Flag{Name: "call-limit", Usage: "daily phone call limit"},
&cli.Int64Flag{Name: "reservation-limit", Usage: "topic reservation limit"}, &cli.Int64Flag{Name: "reservation-limit", Usage: "topic reservation limit"},
&cli.StringFlag{Name: "attachment-file-size-limit", Usage: "per-attachment file size limit"}, &cli.StringFlag{Name: "attachment-file-size-limit", Usage: "per-attachment file size limit"},
&cli.StringFlag{Name: "attachment-total-size-limit", Usage: "total size limit of attachments for the user"}, &cli.StringFlag{Name: "attachment-total-size-limit", Usage: "total size limit of attachments for the user"},
@ -215,6 +221,8 @@ func execTierAdd(c *cli.Context) error {
MessageLimit: c.Int64("message-limit"), MessageLimit: c.Int64("message-limit"),
MessageExpiryDuration: messageExpiryDuration, MessageExpiryDuration: messageExpiryDuration,
EmailLimit: c.Int64("email-limit"), EmailLimit: c.Int64("email-limit"),
SMSLimit: c.Int64("sms-limit"),
CallLimit: c.Int64("call-limit"),
ReservationLimit: c.Int64("reservation-limit"), ReservationLimit: c.Int64("reservation-limit"),
AttachmentFileSizeLimit: attachmentFileSizeLimit, AttachmentFileSizeLimit: attachmentFileSizeLimit,
AttachmentTotalSizeLimit: attachmentTotalSizeLimit, AttachmentTotalSizeLimit: attachmentTotalSizeLimit,
@ -267,6 +275,12 @@ func execTierChange(c *cli.Context) error {
if c.IsSet("email-limit") { if c.IsSet("email-limit") {
tier.EmailLimit = c.Int64("email-limit") tier.EmailLimit = c.Int64("email-limit")
} }
if c.IsSet("sms-limit") {
tier.SMSLimit = c.Int64("sms-limit")
}
if c.IsSet("call-limit") {
tier.CallLimit = c.Int64("call-limit")
}
if c.IsSet("reservation-limit") { if c.IsSet("reservation-limit") {
tier.ReservationLimit = c.Int64("reservation-limit") tier.ReservationLimit = c.Int64("reservation-limit")
} }
@ -357,6 +371,8 @@ func printTier(c *cli.Context, tier *user.Tier) {
fmt.Fprintf(c.App.ErrWriter, "- Message limit: %d\n", tier.MessageLimit) fmt.Fprintf(c.App.ErrWriter, "- Message limit: %d\n", tier.MessageLimit)
fmt.Fprintf(c.App.ErrWriter, "- Message expiry duration: %s (%d seconds)\n", tier.MessageExpiryDuration.String(), int64(tier.MessageExpiryDuration.Seconds())) fmt.Fprintf(c.App.ErrWriter, "- Message expiry duration: %s (%d seconds)\n", tier.MessageExpiryDuration.String(), int64(tier.MessageExpiryDuration.Seconds()))
fmt.Fprintf(c.App.ErrWriter, "- Email limit: %d\n", tier.EmailLimit) fmt.Fprintf(c.App.ErrWriter, "- Email limit: %d\n", tier.EmailLimit)
fmt.Fprintf(c.App.ErrWriter, "- SMS limit: %d\n", tier.SMSLimit)
fmt.Fprintf(c.App.ErrWriter, "- Phone call limit: %d\n", tier.CallLimit)
fmt.Fprintf(c.App.ErrWriter, "- Reservation limit: %d\n", tier.ReservationLimit) fmt.Fprintf(c.App.ErrWriter, "- Reservation limit: %d\n", tier.ReservationLimit)
fmt.Fprintf(c.App.ErrWriter, "- Attachment file size limit: %s\n", util.FormatSize(tier.AttachmentFileSizeLimit)) fmt.Fprintf(c.App.ErrWriter, "- Attachment file size limit: %s\n", util.FormatSize(tier.AttachmentFileSizeLimit))
fmt.Fprintf(c.App.ErrWriter, "- Attachment total size limit: %s\n", util.FormatSize(tier.AttachmentTotalSizeLimit)) fmt.Fprintf(c.App.ErrWriter, "- Attachment total size limit: %s\n", util.FormatSize(tier.AttachmentTotalSizeLimit))

View File

@ -47,6 +47,8 @@ const (
DefaultVisitorMessageDailyLimit = 0 DefaultVisitorMessageDailyLimit = 0
DefaultVisitorEmailLimitBurst = 16 DefaultVisitorEmailLimitBurst = 16
DefaultVisitorEmailLimitReplenish = time.Hour DefaultVisitorEmailLimitReplenish = time.Hour
DefaultVisitorSMSDailyLimit = 10
DefaultVisitorCallDailyLimit = 10
DefaultVisitorAccountCreationLimitBurst = 3 DefaultVisitorAccountCreationLimitBurst = 3
DefaultVisitorAccountCreationLimitReplenish = 24 * time.Hour DefaultVisitorAccountCreationLimitReplenish = 24 * time.Hour
DefaultVisitorAuthFailureLimitBurst = 30 DefaultVisitorAuthFailureLimitBurst = 30
@ -126,6 +128,8 @@ type Config struct {
VisitorMessageDailyLimit int VisitorMessageDailyLimit int
VisitorEmailLimitBurst int VisitorEmailLimitBurst int
VisitorEmailLimitReplenish time.Duration VisitorEmailLimitReplenish time.Duration
VisitorSMSDailyLimit int
VisitorCallDailyLimit int
VisitorAccountCreationLimitBurst int VisitorAccountCreationLimitBurst int
VisitorAccountCreationLimitReplenish time.Duration VisitorAccountCreationLimitReplenish time.Duration
VisitorAuthFailureLimitBurst int VisitorAuthFailureLimitBurst int

View File

@ -126,6 +126,8 @@ var (
errHTTPTooManyRequestsLimitReservations = &errHTTP{42907, http.StatusTooManyRequests, "limit reached: too many topic reservations for this user", "", nil} errHTTPTooManyRequestsLimitReservations = &errHTTP{42907, http.StatusTooManyRequests, "limit reached: too many topic reservations for this user", "", nil}
errHTTPTooManyRequestsLimitMessages = &errHTTP{42908, http.StatusTooManyRequests, "limit reached: daily message quota reached", "https://ntfy.sh/docs/publish/#limitations", nil} errHTTPTooManyRequestsLimitMessages = &errHTTP{42908, http.StatusTooManyRequests, "limit reached: daily message quota reached", "https://ntfy.sh/docs/publish/#limitations", nil}
errHTTPTooManyRequestsLimitAuthFailure = &errHTTP{42909, http.StatusTooManyRequests, "limit reached: too many auth failures", "https://ntfy.sh/docs/publish/#limitations", nil} // FIXME document limit errHTTPTooManyRequestsLimitAuthFailure = &errHTTP{42909, http.StatusTooManyRequests, "limit reached: too many auth failures", "https://ntfy.sh/docs/publish/#limitations", nil} // FIXME document limit
errHTTPTooManyRequestsLimitSMS = &errHTTP{42910, http.StatusTooManyRequests, "limit reached: daily SMS quota reached", "https://ntfy.sh/docs/publish/#limitations", nil}
errHTTPTooManyRequestsLimitCalls = &errHTTP{42911, http.StatusTooManyRequests, "limit reached: daily phone call quota reached", "https://ntfy.sh/docs/publish/#limitations", nil}
errHTTPInternalError = &errHTTP{50001, http.StatusInternalServerError, "internal server error", "", nil} errHTTPInternalError = &errHTTP{50001, http.StatusInternalServerError, "internal server error", "", nil}
errHTTPInternalErrorInvalidPath = &errHTTP{50002, http.StatusInternalServerError, "internal server error: invalid path", "", nil} errHTTPInternalErrorInvalidPath = &errHTTP{50002, http.StatusInternalServerError, "internal server error: invalid path", "", nil}
errHTTPInternalErrorMissingBaseURL = &errHTTP{50003, http.StatusInternalServerError, "internal server error: base-url must be be configured for this feature", "https://ntfy.sh/docs/config/", nil} errHTTPInternalErrorMissingBaseURL = &errHTTP{50003, http.StatusInternalServerError, "internal server error: base-url must be be configured for this feature", "https://ntfy.sh/docs/config/", nil}

View File

@ -683,6 +683,10 @@ func (s *Server) handlePublishInternal(r *http.Request, v *visitor) (*message, e
return nil, errHTTPTooManyRequestsLimitMessages.With(t) return nil, errHTTPTooManyRequestsLimitMessages.With(t)
} else if email != "" && !vrate.EmailAllowed() { } else if email != "" && !vrate.EmailAllowed() {
return nil, errHTTPTooManyRequestsLimitEmails.With(t) return nil, errHTTPTooManyRequestsLimitEmails.With(t)
} else if sms != "" && !vrate.SMSAllowed() {
return nil, errHTTPTooManyRequestsLimitSMS.With(t)
} else if call != "" && !vrate.CallAllowed() {
return nil, errHTTPTooManyRequestsLimitCalls.With(t)
} }
if m.PollID != "" { if m.PollID != "" {
m = newPollRequestMessage(t.ID, m.PollID) m = newPollRequestMessage(t.ID, m.PollID)
@ -726,7 +730,7 @@ func (s *Server) handlePublishInternal(r *http.Request, v *visitor) (*message, e
if s.config.TwilioAccount != "" && sms != "" { if s.config.TwilioAccount != "" && sms != "" {
go s.sendSMS(v, r, m, sms) go s.sendSMS(v, r, m, sms)
} }
if call != "" { if s.config.TwilioAccount != "" && call != "" {
go s.callPhone(v, r, m, call) go s.callPhone(v, r, m, call)
} }
if s.config.UpstreamBaseURL != "" { if s.config.UpstreamBaseURL != "" {

View File

@ -224,11 +224,17 @@
# visitor-request-limit-exempt-hosts: "" # visitor-request-limit-exempt-hosts: ""
# Rate limiting: Hard daily limit of messages per visitor and day. The limit is reset # Rate limiting: Hard daily limit of messages per visitor and day. The limit is reset
# every day at midnight UTC. If the limit is not set (or set to zero), the request # every day at midnight UTC. If the limit is not set (or set to zero), the request limit (see above)
# limit (see above) governs the upper limit. # governs the upper limit. SMS and calls are only supported if the twilio-settings are properly configured.
# #
# visitor-message-daily-limit: 0 # visitor-message-daily-limit: 0
# Rate limiting: Daily limit of SMS and calls per visitor and day. The limit is reset every day
# at midnight UTC. SMS and calls are only supported if the twilio-settings are properly configured.
#
# visitor-sms-daily-limit: 10
# visitor-call-daily-limit: 10
# Rate limiting: Allowed emails per visitor: # Rate limiting: Allowed emails per visitor:
# - visitor-email-limit-burst is the initial bucket of emails each visitor has # - visitor-email-limit-burst is the initial bucket of emails each visitor has
# - visitor-email-limit-replenish is the rate at which the bucket is refilled # - visitor-email-limit-replenish is the rate at which the bucket is refilled

View File

@ -56,6 +56,8 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, r *http.Request, v *vis
Messages: limits.MessageLimit, Messages: limits.MessageLimit,
MessagesExpiryDuration: int64(limits.MessageExpiryDuration.Seconds()), MessagesExpiryDuration: int64(limits.MessageExpiryDuration.Seconds()),
Emails: limits.EmailLimit, Emails: limits.EmailLimit,
SMS: limits.SMSLimit,
Calls: limits.CallLimit,
Reservations: limits.ReservationsLimit, Reservations: limits.ReservationsLimit,
AttachmentTotalSize: limits.AttachmentTotalSizeLimit, AttachmentTotalSize: limits.AttachmentTotalSizeLimit,
AttachmentFileSize: limits.AttachmentFileSizeLimit, AttachmentFileSize: limits.AttachmentFileSizeLimit,
@ -67,6 +69,10 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, r *http.Request, v *vis
MessagesRemaining: stats.MessagesRemaining, MessagesRemaining: stats.MessagesRemaining,
Emails: stats.Emails, Emails: stats.Emails,
EmailsRemaining: stats.EmailsRemaining, EmailsRemaining: stats.EmailsRemaining,
SMS: stats.SMS,
SMSRemaining: stats.SMSRemaining,
Calls: stats.Calls,
CallsRemaining: stats.CallsRemaining,
Reservations: stats.Reservations, Reservations: stats.Reservations,
ReservationsRemaining: stats.ReservationsRemaining, ReservationsRemaining: stats.ReservationsRemaining,
AttachmentTotalSize: stats.AttachmentTotalSize, AttachmentTotalSize: stats.AttachmentTotalSize,

View File

@ -287,6 +287,8 @@ type apiAccountLimits struct {
Messages int64 `json:"messages"` Messages int64 `json:"messages"`
MessagesExpiryDuration int64 `json:"messages_expiry_duration"` MessagesExpiryDuration int64 `json:"messages_expiry_duration"`
Emails int64 `json:"emails"` Emails int64 `json:"emails"`
SMS int64 `json:"sms"`
Calls int64 `json:"calls"`
Reservations int64 `json:"reservations"` Reservations int64 `json:"reservations"`
AttachmentTotalSize int64 `json:"attachment_total_size"` AttachmentTotalSize int64 `json:"attachment_total_size"`
AttachmentFileSize int64 `json:"attachment_file_size"` AttachmentFileSize int64 `json:"attachment_file_size"`
@ -299,6 +301,10 @@ type apiAccountStats struct {
MessagesRemaining int64 `json:"messages_remaining"` MessagesRemaining int64 `json:"messages_remaining"`
Emails int64 `json:"emails"` Emails int64 `json:"emails"`
EmailsRemaining int64 `json:"emails_remaining"` EmailsRemaining int64 `json:"emails_remaining"`
SMS int64 `json:"sms"`
SMSRemaining int64 `json:"sms_remaining"`
Calls int64 `json:"calls"`
CallsRemaining int64 `json:"calls_remaining"`
Reservations int64 `json:"reservations"` Reservations int64 `json:"reservations"`
ReservationsRemaining int64 `json:"reservations_remaining"` ReservationsRemaining int64 `json:"reservations_remaining"`
AttachmentTotalSize int64 `json:"attachment_total_size"` AttachmentTotalSize int64 `json:"attachment_total_size"`

View File

@ -56,6 +56,8 @@ type visitor struct {
requestLimiter *rate.Limiter // Rate limiter for (almost) all requests (including messages) requestLimiter *rate.Limiter // Rate limiter for (almost) all requests (including messages)
messagesLimiter *util.FixedLimiter // Rate limiter for messages messagesLimiter *util.FixedLimiter // Rate limiter for messages
emailsLimiter *util.RateLimiter // Rate limiter for emails emailsLimiter *util.RateLimiter // Rate limiter for emails
smsLimiter *util.FixedLimiter // Rate limiter for SMS
callsLimiter *util.FixedLimiter // Rate limiter for calls
subscriptionLimiter *util.FixedLimiter // Fixed limiter for active subscriptions (ongoing connections) subscriptionLimiter *util.FixedLimiter // Fixed limiter for active subscriptions (ongoing connections)
bandwidthLimiter *util.RateLimiter // Limiter for attachment bandwidth downloads bandwidthLimiter *util.RateLimiter // Limiter for attachment bandwidth downloads
accountLimiter *rate.Limiter // Rate limiter for account creation, may be nil accountLimiter *rate.Limiter // Rate limiter for account creation, may be nil
@ -79,6 +81,8 @@ type visitorLimits struct {
EmailLimit int64 EmailLimit int64
EmailLimitBurst int EmailLimitBurst int
EmailLimitReplenish rate.Limit EmailLimitReplenish rate.Limit
SMSLimit int64
CallLimit int64
ReservationsLimit int64 ReservationsLimit int64
AttachmentTotalSizeLimit int64 AttachmentTotalSizeLimit int64
AttachmentFileSizeLimit int64 AttachmentFileSizeLimit int64
@ -91,6 +95,10 @@ type visitorStats struct {
MessagesRemaining int64 MessagesRemaining int64
Emails int64 Emails int64
EmailsRemaining int64 EmailsRemaining int64
SMS int64
SMSRemaining int64
Calls int64
CallsRemaining int64
Reservations int64 Reservations int64
ReservationsRemaining int64 ReservationsRemaining int64
AttachmentTotalSize int64 AttachmentTotalSize int64
@ -107,10 +115,12 @@ const (
) )
func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Manager, ip netip.Addr, user *user.User) *visitor { func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Manager, ip netip.Addr, user *user.User) *visitor {
var messages, emails int64 var messages, emails, sms, calls int64
if user != nil { if user != nil {
messages = user.Stats.Messages messages = user.Stats.Messages
emails = user.Stats.Emails emails = user.Stats.Emails
sms = user.Stats.SMS
calls = user.Stats.Calls
} }
v := &visitor{ v := &visitor{
config: conf, config: conf,
@ -124,11 +134,13 @@ func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Mana
requestLimiter: nil, // Set in resetLimiters requestLimiter: nil, // Set in resetLimiters
messagesLimiter: nil, // Set in resetLimiters, may be nil messagesLimiter: nil, // Set in resetLimiters, may be nil
emailsLimiter: nil, // Set in resetLimiters emailsLimiter: nil, // Set in resetLimiters
smsLimiter: nil, // Set in resetLimiters, may be nil
callsLimiter: nil, // Set in resetLimiters, may be nil
bandwidthLimiter: nil, // Set in resetLimiters bandwidthLimiter: nil, // Set in resetLimiters
accountLimiter: nil, // Set in resetLimiters, may be nil accountLimiter: nil, // Set in resetLimiters, may be nil
authLimiter: nil, // Set in resetLimiters, may be nil authLimiter: nil, // Set in resetLimiters, may be nil
} }
v.resetLimitersNoLock(messages, emails, false) v.resetLimitersNoLock(messages, emails, sms, calls, false)
return v return v
} }
@ -147,12 +159,22 @@ func (v *visitor) contextNoLock() log.Context {
"visitor_messages": info.Stats.Messages, "visitor_messages": info.Stats.Messages,
"visitor_messages_limit": info.Limits.MessageLimit, "visitor_messages_limit": info.Limits.MessageLimit,
"visitor_messages_remaining": info.Stats.MessagesRemaining, "visitor_messages_remaining": info.Stats.MessagesRemaining,
"visitor_emails": info.Stats.Emails,
"visitor_emails_limit": info.Limits.EmailLimit,
"visitor_emails_remaining": info.Stats.EmailsRemaining,
"visitor_request_limiter_limit": v.requestLimiter.Limit(), "visitor_request_limiter_limit": v.requestLimiter.Limit(),
"visitor_request_limiter_tokens": v.requestLimiter.Tokens(), "visitor_request_limiter_tokens": v.requestLimiter.Tokens(),
} }
if v.config.SMTPSenderFrom != "" {
fields["visitor_emails"] = info.Stats.Emails
fields["visitor_emails_limit"] = info.Limits.EmailLimit
fields["visitor_emails_remaining"] = info.Stats.EmailsRemaining
}
if v.config.TwilioAccount != "" {
fields["visitor_sms"] = info.Stats.SMS
fields["visitor_sms_limit"] = info.Limits.SMSLimit
fields["visitor_sms_remaining"] = info.Stats.SMSRemaining
fields["visitor_calls"] = info.Stats.Calls
fields["visitor_calls_limit"] = info.Limits.CallLimit
fields["visitor_calls_remaining"] = info.Stats.CallsRemaining
}
if v.authLimiter != nil { if v.authLimiter != nil {
fields["visitor_auth_limiter_limit"] = v.authLimiter.Limit() fields["visitor_auth_limiter_limit"] = v.authLimiter.Limit()
fields["visitor_auth_limiter_tokens"] = v.authLimiter.Tokens() fields["visitor_auth_limiter_tokens"] = v.authLimiter.Tokens()
@ -216,6 +238,18 @@ func (v *visitor) EmailAllowed() bool {
return v.emailsLimiter.Allow() return v.emailsLimiter.Allow()
} }
func (v *visitor) SMSAllowed() bool {
v.mu.RLock() // limiters could be replaced!
defer v.mu.RUnlock()
return v.smsLimiter.Allow()
}
func (v *visitor) CallAllowed() bool {
v.mu.RLock() // limiters could be replaced!
defer v.mu.RUnlock()
return v.callsLimiter.Allow()
}
func (v *visitor) SubscriptionAllowed() bool { func (v *visitor) SubscriptionAllowed() bool {
v.mu.RLock() // limiters could be replaced! v.mu.RLock() // limiters could be replaced!
defer v.mu.RUnlock() defer v.mu.RUnlock()
@ -296,6 +330,8 @@ func (v *visitor) Stats() *user.Stats {
return &user.Stats{ return &user.Stats{
Messages: v.messagesLimiter.Value(), Messages: v.messagesLimiter.Value(),
Emails: v.emailsLimiter.Value(), Emails: v.emailsLimiter.Value(),
SMS: v.smsLimiter.Value(),
Calls: v.callsLimiter.Value(),
} }
} }
@ -304,6 +340,8 @@ func (v *visitor) ResetStats() {
defer v.mu.RUnlock() defer v.mu.RUnlock()
v.emailsLimiter.Reset() v.emailsLimiter.Reset()
v.messagesLimiter.Reset() v.messagesLimiter.Reset()
v.smsLimiter.Reset()
v.callsLimiter.Reset()
} }
// User returns the visitor user, or nil if there is none // User returns the visitor user, or nil if there is none
@ -334,11 +372,11 @@ func (v *visitor) SetUser(u *user.User) {
shouldResetLimiters := v.user.TierID() != u.TierID() // TierID works with nil receiver shouldResetLimiters := v.user.TierID() != u.TierID() // TierID works with nil receiver
v.user = u // u may be nil! v.user = u // u may be nil!
if shouldResetLimiters { if shouldResetLimiters {
var messages, emails int64 var messages, emails, sms, calls int64
if u != nil { if u != nil {
messages, emails = u.Stats.Messages, u.Stats.Emails messages, emails, sms, calls = u.Stats.Messages, u.Stats.Emails, u.Stats.SMS, u.Stats.Calls
} }
v.resetLimitersNoLock(messages, emails, true) v.resetLimitersNoLock(messages, emails, sms, calls, true)
} }
} }
@ -353,11 +391,13 @@ func (v *visitor) MaybeUserID() string {
return "" return ""
} }
func (v *visitor) resetLimitersNoLock(messages, emails int64, enqueueUpdate bool) { func (v *visitor) resetLimitersNoLock(messages, emails, sms, calls int64, enqueueUpdate bool) {
limits := v.limitsNoLock() limits := v.limitsNoLock()
v.requestLimiter = rate.NewLimiter(limits.RequestLimitReplenish, limits.RequestLimitBurst) v.requestLimiter = rate.NewLimiter(limits.RequestLimitReplenish, limits.RequestLimitBurst)
v.messagesLimiter = util.NewFixedLimiterWithValue(limits.MessageLimit, messages) v.messagesLimiter = util.NewFixedLimiterWithValue(limits.MessageLimit, messages)
v.emailsLimiter = util.NewRateLimiterWithValue(limits.EmailLimitReplenish, limits.EmailLimitBurst, emails) v.emailsLimiter = util.NewRateLimiterWithValue(limits.EmailLimitReplenish, limits.EmailLimitBurst, emails)
v.smsLimiter = util.NewFixedLimiterWithValue(limits.SMSLimit, sms)
v.callsLimiter = util.NewFixedLimiterWithValue(limits.CallLimit, calls)
v.bandwidthLimiter = util.NewBytesLimiter(int(limits.AttachmentBandwidthLimit), oneDay) v.bandwidthLimiter = util.NewBytesLimiter(int(limits.AttachmentBandwidthLimit), oneDay)
if v.user == nil { if v.user == nil {
v.accountLimiter = rate.NewLimiter(rate.Every(v.config.VisitorAccountCreationLimitReplenish), v.config.VisitorAccountCreationLimitBurst) v.accountLimiter = rate.NewLimiter(rate.Every(v.config.VisitorAccountCreationLimitReplenish), v.config.VisitorAccountCreationLimitBurst)
@ -370,6 +410,8 @@ func (v *visitor) resetLimitersNoLock(messages, emails int64, enqueueUpdate bool
go v.userManager.EnqueueUserStats(v.user.ID, &user.Stats{ go v.userManager.EnqueueUserStats(v.user.ID, &user.Stats{
Messages: messages, Messages: messages,
Emails: emails, Emails: emails,
SMS: sms,
Calls: calls,
}) })
} }
log.Fields(v.contextNoLock()).Debug("Rate limiters reset for visitor") // Must be after function, because contextNoLock() describes rate limiters log.Fields(v.contextNoLock()).Debug("Rate limiters reset for visitor") // Must be after function, because contextNoLock() describes rate limiters
@ -398,6 +440,8 @@ func tierBasedVisitorLimits(conf *Config, tier *user.Tier) *visitorLimits {
EmailLimit: tier.EmailLimit, EmailLimit: tier.EmailLimit,
EmailLimitBurst: util.MinMax(int(float64(tier.EmailLimit)*visitorEmailLimitBurstRate), conf.VisitorEmailLimitBurst, visitorEmailLimitBurstMax), EmailLimitBurst: util.MinMax(int(float64(tier.EmailLimit)*visitorEmailLimitBurstRate), conf.VisitorEmailLimitBurst, visitorEmailLimitBurstMax),
EmailLimitReplenish: dailyLimitToRate(tier.EmailLimit), EmailLimitReplenish: dailyLimitToRate(tier.EmailLimit),
SMSLimit: tier.SMSLimit,
CallLimit: tier.CallLimit,
ReservationsLimit: tier.ReservationLimit, ReservationsLimit: tier.ReservationLimit,
AttachmentTotalSizeLimit: tier.AttachmentTotalSizeLimit, AttachmentTotalSizeLimit: tier.AttachmentTotalSizeLimit,
AttachmentFileSizeLimit: tier.AttachmentFileSizeLimit, AttachmentFileSizeLimit: tier.AttachmentFileSizeLimit,
@ -420,6 +464,8 @@ func configBasedVisitorLimits(conf *Config) *visitorLimits {
EmailLimit: replenishDurationToDailyLimit(conf.VisitorEmailLimitReplenish), // Approximation! EmailLimit: replenishDurationToDailyLimit(conf.VisitorEmailLimitReplenish), // Approximation!
EmailLimitBurst: conf.VisitorEmailLimitBurst, EmailLimitBurst: conf.VisitorEmailLimitBurst,
EmailLimitReplenish: rate.Every(conf.VisitorEmailLimitReplenish), EmailLimitReplenish: rate.Every(conf.VisitorEmailLimitReplenish),
SMSLimit: int64(conf.VisitorSMSDailyLimit),
CallLimit: int64(conf.VisitorCallDailyLimit),
ReservationsLimit: visitorDefaultReservationsLimit, ReservationsLimit: visitorDefaultReservationsLimit,
AttachmentTotalSizeLimit: conf.VisitorAttachmentTotalSizeLimit, AttachmentTotalSizeLimit: conf.VisitorAttachmentTotalSizeLimit,
AttachmentFileSizeLimit: conf.AttachmentFileSizeLimit, AttachmentFileSizeLimit: conf.AttachmentFileSizeLimit,
@ -465,12 +511,18 @@ func (v *visitor) Info() (*visitorInfo, error) {
func (v *visitor) infoLightNoLock() *visitorInfo { func (v *visitor) infoLightNoLock() *visitorInfo {
messages := v.messagesLimiter.Value() messages := v.messagesLimiter.Value()
emails := v.emailsLimiter.Value() emails := v.emailsLimiter.Value()
sms := v.smsLimiter.Value()
calls := v.callsLimiter.Value()
limits := v.limitsNoLock() limits := v.limitsNoLock()
stats := &visitorStats{ stats := &visitorStats{
Messages: messages, Messages: messages,
MessagesRemaining: zeroIfNegative(limits.MessageLimit - messages), MessagesRemaining: zeroIfNegative(limits.MessageLimit - messages),
Emails: emails, Emails: emails,
EmailsRemaining: zeroIfNegative(limits.EmailLimit - emails), EmailsRemaining: zeroIfNegative(limits.EmailLimit - emails),
SMS: sms,
SMSRemaining: zeroIfNegative(limits.SMSLimit - sms),
Calls: calls,
CallsRemaining: zeroIfNegative(limits.CallLimit - calls),
} }
return &visitorInfo{ return &visitorInfo{
Limits: limits, Limits: limits,

View File

@ -55,6 +55,8 @@ const (
messages_limit INT NOT NULL, messages_limit INT NOT NULL,
messages_expiry_duration INT NOT NULL, messages_expiry_duration INT NOT NULL,
emails_limit INT NOT NULL, emails_limit INT NOT NULL,
sms_limit INT NOT NULL,
calls_limit INT NOT NULL,
reservations_limit INT NOT NULL, reservations_limit INT NOT NULL,
attachment_file_size_limit INT NOT NULL, attachment_file_size_limit INT NOT NULL,
attachment_total_size_limit INT NOT NULL, attachment_total_size_limit INT NOT NULL,
@ -76,6 +78,8 @@ const (
sync_topic TEXT NOT NULL, sync_topic TEXT NOT NULL,
stats_messages INT NOT NULL DEFAULT (0), stats_messages INT NOT NULL DEFAULT (0),
stats_emails INT NOT NULL DEFAULT (0), stats_emails INT NOT NULL DEFAULT (0),
stats_sms INT NOT NULL DEFAULT (0),
stats_calls INT NOT NULL DEFAULT (0),
stripe_customer_id TEXT, stripe_customer_id TEXT,
stripe_subscription_id TEXT, stripe_subscription_id TEXT,
stripe_subscription_status TEXT, stripe_subscription_status TEXT,
@ -123,26 +127,26 @@ const (
` `
selectUserByIDQuery = ` selectUserByIDQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stats_sms, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
FROM user u FROM user u
LEFT JOIN tier t on t.id = u.tier_id LEFT JOIN tier t on t.id = u.tier_id
WHERE u.id = ? WHERE u.id = ?
` `
selectUserByNameQuery = ` selectUserByNameQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stats_sms, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
FROM user u FROM user u
LEFT JOIN tier t on t.id = u.tier_id LEFT JOIN tier t on t.id = u.tier_id
WHERE user = ? WHERE user = ?
` `
selectUserByTokenQuery = ` selectUserByTokenQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stats_sms, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
FROM user u FROM user u
JOIN user_token tk on u.id = tk.user_id JOIN user_token tk on u.id = tk.user_id
LEFT JOIN tier t on t.id = u.tier_id LEFT JOIN tier t on t.id = u.tier_id
WHERE tk.token = ? AND (tk.expires = 0 OR tk.expires >= ?) WHERE tk.token = ? AND (tk.expires = 0 OR tk.expires >= ?)
` `
selectUserByStripeCustomerIDQuery = ` selectUserByStripeCustomerIDQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stats_sms, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
FROM user u FROM user u
LEFT JOIN tier t on t.id = u.tier_id LEFT JOIN tier t on t.id = u.tier_id
WHERE u.stripe_customer_id = ? WHERE u.stripe_customer_id = ?
@ -173,8 +177,8 @@ const (
updateUserPassQuery = `UPDATE user SET pass = ? WHERE user = ?` updateUserPassQuery = `UPDATE user SET pass = ? WHERE user = ?`
updateUserRoleQuery = `UPDATE user SET role = ? WHERE user = ?` updateUserRoleQuery = `UPDATE user SET role = ? WHERE user = ?`
updateUserPrefsQuery = `UPDATE user SET prefs = ? WHERE id = ?` updateUserPrefsQuery = `UPDATE user SET prefs = ? WHERE id = ?`
updateUserStatsQuery = `UPDATE user SET stats_messages = ?, stats_emails = ? WHERE id = ?` updateUserStatsQuery = `UPDATE user SET stats_messages = ?, stats_emails = ?, stats_sms = ?, stats_calls = ? WHERE id = ?`
updateUserStatsResetAllQuery = `UPDATE user SET stats_messages = 0, stats_emails = 0` updateUserStatsResetAllQuery = `UPDATE user SET stats_messages = 0, stats_emails = 0, stats_sms = 0, stats_calls = 0`
updateUserDeletedQuery = `UPDATE user SET deleted = ? WHERE id = ?` updateUserDeletedQuery = `UPDATE user SET deleted = ? WHERE id = ?`
deleteUsersMarkedQuery = `DELETE FROM user WHERE deleted < ?` deleteUsersMarkedQuery = `DELETE FROM user WHERE deleted < ?`
deleteUserQuery = `DELETE FROM user WHERE user = ?` deleteUserQuery = `DELETE FROM user WHERE user = ?`
@ -258,25 +262,25 @@ const (
` `
insertTierQuery = ` insertTierQuery = `
INSERT INTO tier (id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id) INSERT INTO tier (id, code, name, messages_limit, messages_expiry_duration, emails_limit, sms_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
` `
updateTierQuery = ` updateTierQuery = `
UPDATE tier UPDATE tier
SET name = ?, messages_limit = ?, messages_expiry_duration = ?, emails_limit = ?, reservations_limit = ?, attachment_file_size_limit = ?, attachment_total_size_limit = ?, attachment_expiry_duration = ?, attachment_bandwidth_limit = ?, stripe_monthly_price_id = ?, stripe_yearly_price_id = ? SET name = ?, messages_limit = ?, messages_expiry_duration = ?, emails_limit = ?, sms_limit = ?, calls_limit = ?, reservations_limit = ?, attachment_file_size_limit = ?, attachment_total_size_limit = ?, attachment_expiry_duration = ?, attachment_bandwidth_limit = ?, stripe_monthly_price_id = ?, stripe_yearly_price_id = ?
WHERE code = ? WHERE code = ?
` `
selectTiersQuery = ` selectTiersQuery = `
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, sms_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
FROM tier FROM tier
` `
selectTierByCodeQuery = ` selectTierByCodeQuery = `
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, sms_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
FROM tier FROM tier
WHERE code = ? WHERE code = ?
` `
selectTierByPriceIDQuery = ` selectTierByPriceIDQuery = `
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, sms_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
FROM tier FROM tier
WHERE (stripe_monthly_price_id = ? OR stripe_yearly_price_id = ?) WHERE (stripe_monthly_price_id = ? OR stripe_yearly_price_id = ?)
` `
@ -293,7 +297,7 @@ const (
// Schema management queries // Schema management queries
const ( const (
currentSchemaVersion = 3 currentSchemaVersion = 4
insertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)` insertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
updateSchemaVersion = `UPDATE schemaVersion SET version = ? WHERE id = 1` updateSchemaVersion = `UPDATE schemaVersion SET version = ? WHERE id = 1`
selectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1` selectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
@ -391,12 +395,21 @@ const (
CREATE UNIQUE INDEX idx_tier_stripe_monthly_price_id ON tier (stripe_monthly_price_id); CREATE UNIQUE INDEX idx_tier_stripe_monthly_price_id ON tier (stripe_monthly_price_id);
CREATE UNIQUE INDEX idx_tier_stripe_yearly_price_id ON tier (stripe_yearly_price_id); CREATE UNIQUE INDEX idx_tier_stripe_yearly_price_id ON tier (stripe_yearly_price_id);
` `
// 3 -> 4
migrate3To4UpdateQueries = `
ALTER TABLE tier ADD COLUMN sms_limit INT NOT NULL DEFAULT (0);
ALTER TABLE tier ADD COLUMN calls_limit INT NOT NULL DEFAULT (0);
ALTER TABLE user ADD COLUMN stats_sms INT NOT NULL DEFAULT (0);
ALTER TABLE user ADD COLUMN stats_calls INT NOT NULL DEFAULT (0);
`
) )
var ( var (
migrations = map[int]func(db *sql.DB) error{ migrations = map[int]func(db *sql.DB) error{
1: migrateFrom1, 1: migrateFrom1,
2: migrateFrom2, 2: migrateFrom2,
3: migrateFrom3,
} }
) )
@ -700,9 +713,11 @@ func (a *Manager) writeUserStatsQueue() error {
"user_id": userID, "user_id": userID,
"messages_count": update.Messages, "messages_count": update.Messages,
"emails_count": update.Emails, "emails_count": update.Emails,
"sms_count": update.SMS,
"calls_count": update.Calls,
}). }).
Trace("Updating stats for user %s", userID) Trace("Updating stats for user %s", userID)
if _, err := tx.Exec(updateUserStatsQuery, update.Messages, update.Emails, userID); err != nil { if _, err := tx.Exec(updateUserStatsQuery, update.Messages, update.Emails, update.SMS, update.Calls, userID); err != nil {
return err return err
} }
} }
@ -911,12 +926,12 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
defer rows.Close() defer rows.Close()
var id, username, hash, role, prefs, syncTopic string var id, username, hash, role, prefs, syncTopic string
var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripeSubscriptionInterval, stripeMonthlyPriceID, stripeYearlyPriceID, tierID, tierCode, tierName sql.NullString var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripeSubscriptionInterval, stripeMonthlyPriceID, stripeYearlyPriceID, tierID, tierCode, tierName sql.NullString
var messages, emails int64 var messages, emails, sms, calls int64
var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt, deleted sql.NullInt64 var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt, deleted sql.NullInt64
if !rows.Next() { if !rows.Next() {
return nil, ErrUserNotFound return nil, ErrUserNotFound
} }
if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionInterval, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierID, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil { if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &sms, &calls, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionInterval, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierID, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil {
return nil, err return nil, err
} else if err := rows.Err(); err != nil { } else if err := rows.Err(); err != nil {
return nil, err return nil, err
@ -931,6 +946,8 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
Stats: &Stats{ Stats: &Stats{
Messages: messages, Messages: messages,
Emails: emails, Emails: emails,
SMS: sms,
Calls: calls,
}, },
Billing: &Billing{ Billing: &Billing{
StripeCustomerID: stripeCustomerID.String, // May be empty StripeCustomerID: stripeCustomerID.String, // May be empty
@ -1259,7 +1276,7 @@ func (a *Manager) AddTier(tier *Tier) error {
if tier.ID == "" { if tier.ID == "" {
tier.ID = util.RandomStringPrefix(tierIDPrefix, tierIDLength) tier.ID = util.RandomStringPrefix(tierIDPrefix, tierIDLength)
} }
if _, err := a.db.Exec(insertTierQuery, tier.ID, tier.Code, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, nullString(tier.StripeMonthlyPriceID), nullString(tier.StripeYearlyPriceID)); err != nil { if _, err := a.db.Exec(insertTierQuery, tier.ID, tier.Code, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.SMSLimit, tier.CallLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, nullString(tier.StripeMonthlyPriceID), nullString(tier.StripeYearlyPriceID)); err != nil {
return err return err
} }
return nil return nil
@ -1267,7 +1284,7 @@ func (a *Manager) AddTier(tier *Tier) error {
// UpdateTier updates a tier's properties in the database // UpdateTier updates a tier's properties in the database
func (a *Manager) UpdateTier(tier *Tier) error { func (a *Manager) UpdateTier(tier *Tier) error {
if _, err := a.db.Exec(updateTierQuery, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, nullString(tier.StripeMonthlyPriceID), nullString(tier.StripeYearlyPriceID), tier.Code); err != nil { if _, err := a.db.Exec(updateTierQuery, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.SMSLimit, tier.CallLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, nullString(tier.StripeMonthlyPriceID), nullString(tier.StripeYearlyPriceID), tier.Code); err != nil {
return err return err
} }
return nil return nil
@ -1336,11 +1353,11 @@ func (a *Manager) TierByStripePrice(priceID string) (*Tier, error) {
func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) { func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) {
var id, code, name string var id, code, name string
var stripeMonthlyPriceID, stripeYearlyPriceID sql.NullString var stripeMonthlyPriceID, stripeYearlyPriceID sql.NullString
var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit sql.NullInt64 var messagesLimit, messagesExpiryDuration, emailsLimit, smsLimit, callsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit sql.NullInt64
if !rows.Next() { if !rows.Next() {
return nil, ErrTierNotFound return nil, ErrTierNotFound
} }
if err := rows.Scan(&id, &code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil { if err := rows.Scan(&id, &code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &smsLimit, &callsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil {
return nil, err return nil, err
} else if err := rows.Err(); err != nil { } else if err := rows.Err(); err != nil {
return nil, err return nil, err
@ -1353,6 +1370,8 @@ func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) {
MessageLimit: messagesLimit.Int64, MessageLimit: messagesLimit.Int64,
MessageExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second, MessageExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second,
EmailLimit: emailsLimit.Int64, EmailLimit: emailsLimit.Int64,
SMSLimit: smsLimit.Int64,
CallLimit: callsLimit.Int64,
ReservationLimit: reservationsLimit.Int64, ReservationLimit: reservationsLimit.Int64,
AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64, AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64, AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
@ -1495,6 +1514,22 @@ func migrateFrom2(db *sql.DB) error {
return tx.Commit() return tx.Commit()
} }
func migrateFrom3(db *sql.DB) error {
log.Tag(tag).Info("Migrating user database schema: from 3 to 4")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(migrate3To4UpdateQueries); err != nil {
return err
}
if _, err := tx.Exec(updateSchemaVersion, 4); err != nil {
return err
}
return tx.Commit()
}
func nullString(s string) sql.NullString { func nullString(s string) sql.NullString {
if s == "" { if s == "" {
return sql.NullString{} return sql.NullString{}

View File

@ -86,6 +86,8 @@ type Tier struct {
MessageLimit int64 // Daily message limit MessageLimit int64 // Daily message limit
MessageExpiryDuration time.Duration // Cache duration for messages MessageExpiryDuration time.Duration // Cache duration for messages
EmailLimit int64 // Daily email limit EmailLimit int64 // Daily email limit
SMSLimit int64 // Daily SMS limit
CallLimit int64 // Daily phone call limit
ReservationLimit int64 // Number of topic reservations allowed by user ReservationLimit int64 // Number of topic reservations allowed by user
AttachmentFileSizeLimit int64 // Max file size per file (bytes) AttachmentFileSizeLimit int64 // Max file size per file (bytes)
AttachmentTotalSizeLimit int64 // Total file size for all files of this user (bytes) AttachmentTotalSizeLimit int64 // Total file size for all files of this user (bytes)
@ -131,6 +133,8 @@ type NotificationPrefs struct {
type Stats struct { type Stats struct {
Messages int64 Messages int64
Emails int64 Emails int64
SMS int64
Calls int64
} }
// Billing is a struct holding a user's billing information // Billing is a struct holding a user's billing information