From f9e2d6ddcbe829ca1297ebc7bf9c455836281508 Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Sun, 7 May 2023 11:59:15 -0400 Subject: [PATCH] Add limiters and database changes --- cmd/serve.go | 6 ++++ cmd/tier.go | 16 +++++++++ server/config.go | 4 +++ server/errors.go | 2 ++ server/server.go | 6 +++- server/server.yml | 10 ++++-- server/server_account.go | 6 ++++ server/types.go | 6 ++++ server/visitor.go | 70 ++++++++++++++++++++++++++++++++----- user/manager.go | 75 +++++++++++++++++++++++++++++----------- user/types.go | 4 +++ 11 files changed, 173 insertions(+), 32 deletions(-) diff --git a/cmd/serve.go b/cmd/serve.go index bef09e1c..6c729753 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -83,6 +83,8 @@ var flagsServe = append( altsrc.NewStringFlag(&cli.StringFlag{Name: "visitor-request-limit-exempt-hosts", Aliases: []string{"visitor_request_limit_exempt_hosts"}, EnvVars: []string{"NTFY_VISITOR_REQUEST_LIMIT_EXEMPT_HOSTS"}, Value: "", Usage: "hostnames and/or IP addresses of hosts that will be exempt from the visitor request limit"}), altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-message-daily-limit", Aliases: []string{"visitor_message_daily_limit"}, EnvVars: []string{"NTFY_VISITOR_MESSAGE_DAILY_LIMIT"}, Value: server.DefaultVisitorMessageDailyLimit, Usage: "max messages per visitor per day, derived from request limit if unset"}), altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-email-limit-burst", Aliases: []string{"visitor_email_limit_burst"}, EnvVars: []string{"NTFY_VISITOR_EMAIL_LIMIT_BURST"}, Value: server.DefaultVisitorEmailLimitBurst, Usage: "initial limit of e-mails per visitor"}), + altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-sms-daily-limit", Aliases: []string{"visitor_sms_daily_limit"}, EnvVars: []string{"NTFY_VISITOR_SMS_DAILY_LIMIT"}, Value: server.DefaultVisitorSMSDailyLimit, Usage: "max number of SMS messages per visitor per day"}), + altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-call-daily-limit", Aliases: []string{"visitor_call_daily_limit"}, EnvVars: []string{"NTFY_VISITOR_CALL_DAILY_LIMIT"}, Value: server.DefaultVisitorCallDailyLimit, Usage: "max number of phone calls per visitor per day"}), altsrc.NewDurationFlag(&cli.DurationFlag{Name: "visitor-email-limit-replenish", Aliases: []string{"visitor_email_limit_replenish"}, EnvVars: []string{"NTFY_VISITOR_EMAIL_LIMIT_REPLENISH"}, Value: server.DefaultVisitorEmailLimitReplenish, Usage: "interval at which burst limit is replenished (one per x)"}), altsrc.NewBoolFlag(&cli.BoolFlag{Name: "visitor-subscriber-rate-limiting", Aliases: []string{"visitor_subscriber_rate_limiting"}, EnvVars: []string{"NTFY_VISITOR_SUBSCRIBER_RATE_LIMITING"}, Value: false, Usage: "enables subscriber-based rate limiting"}), altsrc.NewBoolFlag(&cli.BoolFlag{Name: "behind-proxy", Aliases: []string{"behind_proxy", "P"}, EnvVars: []string{"NTFY_BEHIND_PROXY"}, Value: false, Usage: "if set, use X-Forwarded-For header to determine visitor IP address (for rate limiting)"}), @@ -168,6 +170,8 @@ func execServe(c *cli.Context) error { visitorMessageDailyLimit := c.Int("visitor-message-daily-limit") visitorEmailLimitBurst := c.Int("visitor-email-limit-burst") visitorEmailLimitReplenish := c.Duration("visitor-email-limit-replenish") + visitorSMSDailyLimit := c.Int("visitor-sms-daily-limit") + visitorCallDailyLimit := c.Int("visitor-call-daily-limit") behindProxy := c.Bool("behind-proxy") stripeSecretKey := c.String("stripe-secret-key") stripeWebhookKey := c.String("stripe-webhook-key") @@ -329,6 +333,8 @@ func execServe(c *cli.Context) error { conf.VisitorMessageDailyLimit = visitorMessageDailyLimit conf.VisitorEmailLimitBurst = visitorEmailLimitBurst conf.VisitorEmailLimitReplenish = visitorEmailLimitReplenish + conf.VisitorSMSDailyLimit = visitorSMSDailyLimit + conf.VisitorCallDailyLimit = visitorCallDailyLimit conf.VisitorSubscriberRateLimiting = visitorSubscriberRateLimiting conf.BehindProxy = behindProxy conf.StripeSecretKey = stripeSecretKey diff --git a/cmd/tier.go b/cmd/tier.go index c0b83d71..6b95bdd2 100644 --- a/cmd/tier.go +++ b/cmd/tier.go @@ -18,6 +18,8 @@ const ( defaultMessageLimit = 5000 defaultMessageExpiryDuration = "12h" defaultEmailLimit = 20 + defaultSMSLimit = 10 + defaultCallLimit = 10 defaultReservationLimit = 3 defaultAttachmentFileSizeLimit = "15M" defaultAttachmentTotalSizeLimit = "100M" @@ -48,6 +50,8 @@ var cmdTier = &cli.Command{ &cli.Int64Flag{Name: "message-limit", Value: defaultMessageLimit, Usage: "daily message limit"}, &cli.StringFlag{Name: "message-expiry-duration", Value: defaultMessageExpiryDuration, Usage: "duration after which messages are deleted"}, &cli.Int64Flag{Name: "email-limit", Value: defaultEmailLimit, Usage: "daily email limit"}, + &cli.Int64Flag{Name: "sms-limit", Value: defaultSMSLimit, Usage: "daily SMS limit"}, + &cli.Int64Flag{Name: "call-limit", Value: defaultCallLimit, Usage: "daily phone call limit"}, &cli.Int64Flag{Name: "reservation-limit", Value: defaultReservationLimit, Usage: "topic reservation limit"}, &cli.StringFlag{Name: "attachment-file-size-limit", Value: defaultAttachmentFileSizeLimit, Usage: "per-attachment file size limit"}, &cli.StringFlag{Name: "attachment-total-size-limit", Value: defaultAttachmentTotalSizeLimit, Usage: "total size limit of attachments for the user"}, @@ -91,6 +95,8 @@ Examples: &cli.Int64Flag{Name: "message-limit", Usage: "daily message limit"}, &cli.StringFlag{Name: "message-expiry-duration", Usage: "duration after which messages are deleted"}, &cli.Int64Flag{Name: "email-limit", Usage: "daily email limit"}, + &cli.Int64Flag{Name: "sms-limit", Usage: "daily SMS limit"}, + &cli.Int64Flag{Name: "call-limit", Usage: "daily phone call limit"}, &cli.Int64Flag{Name: "reservation-limit", Usage: "topic reservation limit"}, &cli.StringFlag{Name: "attachment-file-size-limit", Usage: "per-attachment file size limit"}, &cli.StringFlag{Name: "attachment-total-size-limit", Usage: "total size limit of attachments for the user"}, @@ -215,6 +221,8 @@ func execTierAdd(c *cli.Context) error { MessageLimit: c.Int64("message-limit"), MessageExpiryDuration: messageExpiryDuration, EmailLimit: c.Int64("email-limit"), + SMSLimit: c.Int64("sms-limit"), + CallLimit: c.Int64("call-limit"), ReservationLimit: c.Int64("reservation-limit"), AttachmentFileSizeLimit: attachmentFileSizeLimit, AttachmentTotalSizeLimit: attachmentTotalSizeLimit, @@ -267,6 +275,12 @@ func execTierChange(c *cli.Context) error { if c.IsSet("email-limit") { tier.EmailLimit = c.Int64("email-limit") } + if c.IsSet("sms-limit") { + tier.SMSLimit = c.Int64("sms-limit") + } + if c.IsSet("call-limit") { + tier.CallLimit = c.Int64("call-limit") + } if c.IsSet("reservation-limit") { tier.ReservationLimit = c.Int64("reservation-limit") } @@ -357,6 +371,8 @@ func printTier(c *cli.Context, tier *user.Tier) { fmt.Fprintf(c.App.ErrWriter, "- Message limit: %d\n", tier.MessageLimit) fmt.Fprintf(c.App.ErrWriter, "- Message expiry duration: %s (%d seconds)\n", tier.MessageExpiryDuration.String(), int64(tier.MessageExpiryDuration.Seconds())) fmt.Fprintf(c.App.ErrWriter, "- Email limit: %d\n", tier.EmailLimit) + fmt.Fprintf(c.App.ErrWriter, "- SMS limit: %d\n", tier.SMSLimit) + fmt.Fprintf(c.App.ErrWriter, "- Phone call limit: %d\n", tier.CallLimit) fmt.Fprintf(c.App.ErrWriter, "- Reservation limit: %d\n", tier.ReservationLimit) fmt.Fprintf(c.App.ErrWriter, "- Attachment file size limit: %s\n", util.FormatSize(tier.AttachmentFileSizeLimit)) fmt.Fprintf(c.App.ErrWriter, "- Attachment total size limit: %s\n", util.FormatSize(tier.AttachmentTotalSizeLimit)) diff --git a/server/config.go b/server/config.go index 4f1cbef6..b6d57d90 100644 --- a/server/config.go +++ b/server/config.go @@ -47,6 +47,8 @@ const ( DefaultVisitorMessageDailyLimit = 0 DefaultVisitorEmailLimitBurst = 16 DefaultVisitorEmailLimitReplenish = time.Hour + DefaultVisitorSMSDailyLimit = 10 + DefaultVisitorCallDailyLimit = 10 DefaultVisitorAccountCreationLimitBurst = 3 DefaultVisitorAccountCreationLimitReplenish = 24 * time.Hour DefaultVisitorAuthFailureLimitBurst = 30 @@ -126,6 +128,8 @@ type Config struct { VisitorMessageDailyLimit int VisitorEmailLimitBurst int VisitorEmailLimitReplenish time.Duration + VisitorSMSDailyLimit int + VisitorCallDailyLimit int VisitorAccountCreationLimitBurst int VisitorAccountCreationLimitReplenish time.Duration VisitorAuthFailureLimitBurst int diff --git a/server/errors.go b/server/errors.go index 236b4e0c..d02fb071 100644 --- a/server/errors.go +++ b/server/errors.go @@ -126,6 +126,8 @@ var ( errHTTPTooManyRequestsLimitReservations = &errHTTP{42907, http.StatusTooManyRequests, "limit reached: too many topic reservations for this user", "", nil} errHTTPTooManyRequestsLimitMessages = &errHTTP{42908, http.StatusTooManyRequests, "limit reached: daily message quota reached", "https://ntfy.sh/docs/publish/#limitations", nil} errHTTPTooManyRequestsLimitAuthFailure = &errHTTP{42909, http.StatusTooManyRequests, "limit reached: too many auth failures", "https://ntfy.sh/docs/publish/#limitations", nil} // FIXME document limit + errHTTPTooManyRequestsLimitSMS = &errHTTP{42910, http.StatusTooManyRequests, "limit reached: daily SMS quota reached", "https://ntfy.sh/docs/publish/#limitations", nil} + errHTTPTooManyRequestsLimitCalls = &errHTTP{42911, http.StatusTooManyRequests, "limit reached: daily phone call quota reached", "https://ntfy.sh/docs/publish/#limitations", nil} errHTTPInternalError = &errHTTP{50001, http.StatusInternalServerError, "internal server error", "", nil} errHTTPInternalErrorInvalidPath = &errHTTP{50002, http.StatusInternalServerError, "internal server error: invalid path", "", nil} errHTTPInternalErrorMissingBaseURL = &errHTTP{50003, http.StatusInternalServerError, "internal server error: base-url must be be configured for this feature", "https://ntfy.sh/docs/config/", nil} diff --git a/server/server.go b/server/server.go index 1a9309ce..8c2f83ce 100644 --- a/server/server.go +++ b/server/server.go @@ -683,6 +683,10 @@ func (s *Server) handlePublishInternal(r *http.Request, v *visitor) (*message, e return nil, errHTTPTooManyRequestsLimitMessages.With(t) } else if email != "" && !vrate.EmailAllowed() { return nil, errHTTPTooManyRequestsLimitEmails.With(t) + } else if sms != "" && !vrate.SMSAllowed() { + return nil, errHTTPTooManyRequestsLimitSMS.With(t) + } else if call != "" && !vrate.CallAllowed() { + return nil, errHTTPTooManyRequestsLimitCalls.With(t) } if m.PollID != "" { m = newPollRequestMessage(t.ID, m.PollID) @@ -726,7 +730,7 @@ func (s *Server) handlePublishInternal(r *http.Request, v *visitor) (*message, e if s.config.TwilioAccount != "" && sms != "" { go s.sendSMS(v, r, m, sms) } - if call != "" { + if s.config.TwilioAccount != "" && call != "" { go s.callPhone(v, r, m, call) } if s.config.UpstreamBaseURL != "" { diff --git a/server/server.yml b/server/server.yml index 9e515ec3..fb4d1d99 100644 --- a/server/server.yml +++ b/server/server.yml @@ -224,11 +224,17 @@ # visitor-request-limit-exempt-hosts: "" # Rate limiting: Hard daily limit of messages per visitor and day. The limit is reset -# every day at midnight UTC. If the limit is not set (or set to zero), the request -# limit (see above) governs the upper limit. +# every day at midnight UTC. If the limit is not set (or set to zero), the request limit (see above) +# governs the upper limit. SMS and calls are only supported if the twilio-settings are properly configured. # # visitor-message-daily-limit: 0 +# Rate limiting: Daily limit of SMS and calls per visitor and day. The limit is reset every day +# at midnight UTC. SMS and calls are only supported if the twilio-settings are properly configured. +# +# visitor-sms-daily-limit: 10 +# visitor-call-daily-limit: 10 + # Rate limiting: Allowed emails per visitor: # - visitor-email-limit-burst is the initial bucket of emails each visitor has # - visitor-email-limit-replenish is the rate at which the bucket is refilled diff --git a/server/server_account.go b/server/server_account.go index 1b2c0ce4..bdc42903 100644 --- a/server/server_account.go +++ b/server/server_account.go @@ -56,6 +56,8 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, r *http.Request, v *vis Messages: limits.MessageLimit, MessagesExpiryDuration: int64(limits.MessageExpiryDuration.Seconds()), Emails: limits.EmailLimit, + SMS: limits.SMSLimit, + Calls: limits.CallLimit, Reservations: limits.ReservationsLimit, AttachmentTotalSize: limits.AttachmentTotalSizeLimit, AttachmentFileSize: limits.AttachmentFileSizeLimit, @@ -67,6 +69,10 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, r *http.Request, v *vis MessagesRemaining: stats.MessagesRemaining, Emails: stats.Emails, EmailsRemaining: stats.EmailsRemaining, + SMS: stats.SMS, + SMSRemaining: stats.SMSRemaining, + Calls: stats.Calls, + CallsRemaining: stats.CallsRemaining, Reservations: stats.Reservations, ReservationsRemaining: stats.ReservationsRemaining, AttachmentTotalSize: stats.AttachmentTotalSize, diff --git a/server/types.go b/server/types.go index 563cafbb..ae6724f5 100644 --- a/server/types.go +++ b/server/types.go @@ -287,6 +287,8 @@ type apiAccountLimits struct { Messages int64 `json:"messages"` MessagesExpiryDuration int64 `json:"messages_expiry_duration"` Emails int64 `json:"emails"` + SMS int64 `json:"sms"` + Calls int64 `json:"calls"` Reservations int64 `json:"reservations"` AttachmentTotalSize int64 `json:"attachment_total_size"` AttachmentFileSize int64 `json:"attachment_file_size"` @@ -299,6 +301,10 @@ type apiAccountStats struct { MessagesRemaining int64 `json:"messages_remaining"` Emails int64 `json:"emails"` EmailsRemaining int64 `json:"emails_remaining"` + SMS int64 `json:"sms"` + SMSRemaining int64 `json:"sms_remaining"` + Calls int64 `json:"calls"` + CallsRemaining int64 `json:"calls_remaining"` Reservations int64 `json:"reservations"` ReservationsRemaining int64 `json:"reservations_remaining"` AttachmentTotalSize int64 `json:"attachment_total_size"` diff --git a/server/visitor.go b/server/visitor.go index 63a3ac60..4de51e67 100644 --- a/server/visitor.go +++ b/server/visitor.go @@ -56,6 +56,8 @@ type visitor struct { requestLimiter *rate.Limiter // Rate limiter for (almost) all requests (including messages) messagesLimiter *util.FixedLimiter // Rate limiter for messages emailsLimiter *util.RateLimiter // Rate limiter for emails + smsLimiter *util.FixedLimiter // Rate limiter for SMS + callsLimiter *util.FixedLimiter // Rate limiter for calls subscriptionLimiter *util.FixedLimiter // Fixed limiter for active subscriptions (ongoing connections) bandwidthLimiter *util.RateLimiter // Limiter for attachment bandwidth downloads accountLimiter *rate.Limiter // Rate limiter for account creation, may be nil @@ -79,6 +81,8 @@ type visitorLimits struct { EmailLimit int64 EmailLimitBurst int EmailLimitReplenish rate.Limit + SMSLimit int64 + CallLimit int64 ReservationsLimit int64 AttachmentTotalSizeLimit int64 AttachmentFileSizeLimit int64 @@ -91,6 +95,10 @@ type visitorStats struct { MessagesRemaining int64 Emails int64 EmailsRemaining int64 + SMS int64 + SMSRemaining int64 + Calls int64 + CallsRemaining int64 Reservations int64 ReservationsRemaining int64 AttachmentTotalSize int64 @@ -107,10 +115,12 @@ const ( ) func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Manager, ip netip.Addr, user *user.User) *visitor { - var messages, emails int64 + var messages, emails, sms, calls int64 if user != nil { messages = user.Stats.Messages emails = user.Stats.Emails + sms = user.Stats.SMS + calls = user.Stats.Calls } v := &visitor{ config: conf, @@ -124,11 +134,13 @@ func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Mana requestLimiter: nil, // Set in resetLimiters messagesLimiter: nil, // Set in resetLimiters, may be nil emailsLimiter: nil, // Set in resetLimiters + smsLimiter: nil, // Set in resetLimiters, may be nil + callsLimiter: nil, // Set in resetLimiters, may be nil bandwidthLimiter: nil, // Set in resetLimiters accountLimiter: nil, // Set in resetLimiters, may be nil authLimiter: nil, // Set in resetLimiters, may be nil } - v.resetLimitersNoLock(messages, emails, false) + v.resetLimitersNoLock(messages, emails, sms, calls, false) return v } @@ -147,12 +159,22 @@ func (v *visitor) contextNoLock() log.Context { "visitor_messages": info.Stats.Messages, "visitor_messages_limit": info.Limits.MessageLimit, "visitor_messages_remaining": info.Stats.MessagesRemaining, - "visitor_emails": info.Stats.Emails, - "visitor_emails_limit": info.Limits.EmailLimit, - "visitor_emails_remaining": info.Stats.EmailsRemaining, "visitor_request_limiter_limit": v.requestLimiter.Limit(), "visitor_request_limiter_tokens": v.requestLimiter.Tokens(), } + if v.config.SMTPSenderFrom != "" { + fields["visitor_emails"] = info.Stats.Emails + fields["visitor_emails_limit"] = info.Limits.EmailLimit + fields["visitor_emails_remaining"] = info.Stats.EmailsRemaining + } + if v.config.TwilioAccount != "" { + fields["visitor_sms"] = info.Stats.SMS + fields["visitor_sms_limit"] = info.Limits.SMSLimit + fields["visitor_sms_remaining"] = info.Stats.SMSRemaining + fields["visitor_calls"] = info.Stats.Calls + fields["visitor_calls_limit"] = info.Limits.CallLimit + fields["visitor_calls_remaining"] = info.Stats.CallsRemaining + } if v.authLimiter != nil { fields["visitor_auth_limiter_limit"] = v.authLimiter.Limit() fields["visitor_auth_limiter_tokens"] = v.authLimiter.Tokens() @@ -216,6 +238,18 @@ func (v *visitor) EmailAllowed() bool { return v.emailsLimiter.Allow() } +func (v *visitor) SMSAllowed() bool { + v.mu.RLock() // limiters could be replaced! + defer v.mu.RUnlock() + return v.smsLimiter.Allow() +} + +func (v *visitor) CallAllowed() bool { + v.mu.RLock() // limiters could be replaced! + defer v.mu.RUnlock() + return v.callsLimiter.Allow() +} + func (v *visitor) SubscriptionAllowed() bool { v.mu.RLock() // limiters could be replaced! defer v.mu.RUnlock() @@ -296,6 +330,8 @@ func (v *visitor) Stats() *user.Stats { return &user.Stats{ Messages: v.messagesLimiter.Value(), Emails: v.emailsLimiter.Value(), + SMS: v.smsLimiter.Value(), + Calls: v.callsLimiter.Value(), } } @@ -304,6 +340,8 @@ func (v *visitor) ResetStats() { defer v.mu.RUnlock() v.emailsLimiter.Reset() v.messagesLimiter.Reset() + v.smsLimiter.Reset() + v.callsLimiter.Reset() } // User returns the visitor user, or nil if there is none @@ -334,11 +372,11 @@ func (v *visitor) SetUser(u *user.User) { shouldResetLimiters := v.user.TierID() != u.TierID() // TierID works with nil receiver v.user = u // u may be nil! if shouldResetLimiters { - var messages, emails int64 + var messages, emails, sms, calls int64 if u != nil { - messages, emails = u.Stats.Messages, u.Stats.Emails + messages, emails, sms, calls = u.Stats.Messages, u.Stats.Emails, u.Stats.SMS, u.Stats.Calls } - v.resetLimitersNoLock(messages, emails, true) + v.resetLimitersNoLock(messages, emails, sms, calls, true) } } @@ -353,11 +391,13 @@ func (v *visitor) MaybeUserID() string { return "" } -func (v *visitor) resetLimitersNoLock(messages, emails int64, enqueueUpdate bool) { +func (v *visitor) resetLimitersNoLock(messages, emails, sms, calls int64, enqueueUpdate bool) { limits := v.limitsNoLock() v.requestLimiter = rate.NewLimiter(limits.RequestLimitReplenish, limits.RequestLimitBurst) v.messagesLimiter = util.NewFixedLimiterWithValue(limits.MessageLimit, messages) v.emailsLimiter = util.NewRateLimiterWithValue(limits.EmailLimitReplenish, limits.EmailLimitBurst, emails) + v.smsLimiter = util.NewFixedLimiterWithValue(limits.SMSLimit, sms) + v.callsLimiter = util.NewFixedLimiterWithValue(limits.CallLimit, calls) v.bandwidthLimiter = util.NewBytesLimiter(int(limits.AttachmentBandwidthLimit), oneDay) if v.user == nil { v.accountLimiter = rate.NewLimiter(rate.Every(v.config.VisitorAccountCreationLimitReplenish), v.config.VisitorAccountCreationLimitBurst) @@ -370,6 +410,8 @@ func (v *visitor) resetLimitersNoLock(messages, emails int64, enqueueUpdate bool go v.userManager.EnqueueUserStats(v.user.ID, &user.Stats{ Messages: messages, Emails: emails, + SMS: sms, + Calls: calls, }) } log.Fields(v.contextNoLock()).Debug("Rate limiters reset for visitor") // Must be after function, because contextNoLock() describes rate limiters @@ -398,6 +440,8 @@ func tierBasedVisitorLimits(conf *Config, tier *user.Tier) *visitorLimits { EmailLimit: tier.EmailLimit, EmailLimitBurst: util.MinMax(int(float64(tier.EmailLimit)*visitorEmailLimitBurstRate), conf.VisitorEmailLimitBurst, visitorEmailLimitBurstMax), EmailLimitReplenish: dailyLimitToRate(tier.EmailLimit), + SMSLimit: tier.SMSLimit, + CallLimit: tier.CallLimit, ReservationsLimit: tier.ReservationLimit, AttachmentTotalSizeLimit: tier.AttachmentTotalSizeLimit, AttachmentFileSizeLimit: tier.AttachmentFileSizeLimit, @@ -420,6 +464,8 @@ func configBasedVisitorLimits(conf *Config) *visitorLimits { EmailLimit: replenishDurationToDailyLimit(conf.VisitorEmailLimitReplenish), // Approximation! EmailLimitBurst: conf.VisitorEmailLimitBurst, EmailLimitReplenish: rate.Every(conf.VisitorEmailLimitReplenish), + SMSLimit: int64(conf.VisitorSMSDailyLimit), + CallLimit: int64(conf.VisitorCallDailyLimit), ReservationsLimit: visitorDefaultReservationsLimit, AttachmentTotalSizeLimit: conf.VisitorAttachmentTotalSizeLimit, AttachmentFileSizeLimit: conf.AttachmentFileSizeLimit, @@ -465,12 +511,18 @@ func (v *visitor) Info() (*visitorInfo, error) { func (v *visitor) infoLightNoLock() *visitorInfo { messages := v.messagesLimiter.Value() emails := v.emailsLimiter.Value() + sms := v.smsLimiter.Value() + calls := v.callsLimiter.Value() limits := v.limitsNoLock() stats := &visitorStats{ Messages: messages, MessagesRemaining: zeroIfNegative(limits.MessageLimit - messages), Emails: emails, EmailsRemaining: zeroIfNegative(limits.EmailLimit - emails), + SMS: sms, + SMSRemaining: zeroIfNegative(limits.SMSLimit - sms), + Calls: calls, + CallsRemaining: zeroIfNegative(limits.CallLimit - calls), } return &visitorInfo{ Limits: limits, diff --git a/user/manager.go b/user/manager.go index b2898ae8..3effd5cd 100644 --- a/user/manager.go +++ b/user/manager.go @@ -55,6 +55,8 @@ const ( messages_limit INT NOT NULL, messages_expiry_duration INT NOT NULL, emails_limit INT NOT NULL, + sms_limit INT NOT NULL, + calls_limit INT NOT NULL, reservations_limit INT NOT NULL, attachment_file_size_limit INT NOT NULL, attachment_total_size_limit INT NOT NULL, @@ -76,6 +78,8 @@ const ( sync_topic TEXT NOT NULL, stats_messages INT NOT NULL DEFAULT (0), stats_emails INT NOT NULL DEFAULT (0), + stats_sms INT NOT NULL DEFAULT (0), + stats_calls INT NOT NULL DEFAULT (0), stripe_customer_id TEXT, stripe_subscription_id TEXT, stripe_subscription_status TEXT, @@ -123,26 +127,26 @@ const ( ` selectUserByIDQuery = ` - SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id + SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stats_sms, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id FROM user u LEFT JOIN tier t on t.id = u.tier_id WHERE u.id = ? ` selectUserByNameQuery = ` - SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id + SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stats_sms, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id FROM user u LEFT JOIN tier t on t.id = u.tier_id WHERE user = ? ` selectUserByTokenQuery = ` - SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id + SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stats_sms, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id FROM user u JOIN user_token tk on u.id = tk.user_id LEFT JOIN tier t on t.id = u.tier_id WHERE tk.token = ? AND (tk.expires = 0 OR tk.expires >= ?) ` selectUserByStripeCustomerIDQuery = ` - SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id + SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stats_sms, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id FROM user u LEFT JOIN tier t on t.id = u.tier_id WHERE u.stripe_customer_id = ? @@ -173,8 +177,8 @@ const ( updateUserPassQuery = `UPDATE user SET pass = ? WHERE user = ?` updateUserRoleQuery = `UPDATE user SET role = ? WHERE user = ?` updateUserPrefsQuery = `UPDATE user SET prefs = ? WHERE id = ?` - updateUserStatsQuery = `UPDATE user SET stats_messages = ?, stats_emails = ? WHERE id = ?` - updateUserStatsResetAllQuery = `UPDATE user SET stats_messages = 0, stats_emails = 0` + updateUserStatsQuery = `UPDATE user SET stats_messages = ?, stats_emails = ?, stats_sms = ?, stats_calls = ? WHERE id = ?` + updateUserStatsResetAllQuery = `UPDATE user SET stats_messages = 0, stats_emails = 0, stats_sms = 0, stats_calls = 0` updateUserDeletedQuery = `UPDATE user SET deleted = ? WHERE id = ?` deleteUsersMarkedQuery = `DELETE FROM user WHERE deleted < ?` deleteUserQuery = `DELETE FROM user WHERE user = ?` @@ -258,25 +262,25 @@ const ( ` insertTierQuery = ` - INSERT INTO tier (id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + INSERT INTO tier (id, code, name, messages_limit, messages_expiry_duration, emails_limit, sms_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) ` updateTierQuery = ` UPDATE tier - SET name = ?, messages_limit = ?, messages_expiry_duration = ?, emails_limit = ?, reservations_limit = ?, attachment_file_size_limit = ?, attachment_total_size_limit = ?, attachment_expiry_duration = ?, attachment_bandwidth_limit = ?, stripe_monthly_price_id = ?, stripe_yearly_price_id = ? + SET name = ?, messages_limit = ?, messages_expiry_duration = ?, emails_limit = ?, sms_limit = ?, calls_limit = ?, reservations_limit = ?, attachment_file_size_limit = ?, attachment_total_size_limit = ?, attachment_expiry_duration = ?, attachment_bandwidth_limit = ?, stripe_monthly_price_id = ?, stripe_yearly_price_id = ? WHERE code = ? ` selectTiersQuery = ` - SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id + SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, sms_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id FROM tier ` selectTierByCodeQuery = ` - SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id + SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, sms_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id FROM tier WHERE code = ? ` selectTierByPriceIDQuery = ` - SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id + SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, sms_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id FROM tier WHERE (stripe_monthly_price_id = ? OR stripe_yearly_price_id = ?) ` @@ -293,7 +297,7 @@ const ( // Schema management queries const ( - currentSchemaVersion = 3 + currentSchemaVersion = 4 insertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)` updateSchemaVersion = `UPDATE schemaVersion SET version = ? WHERE id = 1` selectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1` @@ -391,12 +395,21 @@ const ( CREATE UNIQUE INDEX idx_tier_stripe_monthly_price_id ON tier (stripe_monthly_price_id); CREATE UNIQUE INDEX idx_tier_stripe_yearly_price_id ON tier (stripe_yearly_price_id); ` + + // 3 -> 4 + migrate3To4UpdateQueries = ` + ALTER TABLE tier ADD COLUMN sms_limit INT NOT NULL DEFAULT (0); + ALTER TABLE tier ADD COLUMN calls_limit INT NOT NULL DEFAULT (0); + ALTER TABLE user ADD COLUMN stats_sms INT NOT NULL DEFAULT (0); + ALTER TABLE user ADD COLUMN stats_calls INT NOT NULL DEFAULT (0); + ` ) var ( migrations = map[int]func(db *sql.DB) error{ 1: migrateFrom1, 2: migrateFrom2, + 3: migrateFrom3, } ) @@ -700,9 +713,11 @@ func (a *Manager) writeUserStatsQueue() error { "user_id": userID, "messages_count": update.Messages, "emails_count": update.Emails, + "sms_count": update.SMS, + "calls_count": update.Calls, }). Trace("Updating stats for user %s", userID) - if _, err := tx.Exec(updateUserStatsQuery, update.Messages, update.Emails, userID); err != nil { + if _, err := tx.Exec(updateUserStatsQuery, update.Messages, update.Emails, update.SMS, update.Calls, userID); err != nil { return err } } @@ -911,12 +926,12 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) { defer rows.Close() var id, username, hash, role, prefs, syncTopic string var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripeSubscriptionInterval, stripeMonthlyPriceID, stripeYearlyPriceID, tierID, tierCode, tierName sql.NullString - var messages, emails int64 + var messages, emails, sms, calls int64 var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt, deleted sql.NullInt64 if !rows.Next() { return nil, ErrUserNotFound } - if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionInterval, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierID, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil { + if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &sms, &calls, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionInterval, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierID, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil { return nil, err } else if err := rows.Err(); err != nil { return nil, err @@ -931,6 +946,8 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) { Stats: &Stats{ Messages: messages, Emails: emails, + SMS: sms, + Calls: calls, }, Billing: &Billing{ StripeCustomerID: stripeCustomerID.String, // May be empty @@ -1259,7 +1276,7 @@ func (a *Manager) AddTier(tier *Tier) error { if tier.ID == "" { tier.ID = util.RandomStringPrefix(tierIDPrefix, tierIDLength) } - if _, err := a.db.Exec(insertTierQuery, tier.ID, tier.Code, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, nullString(tier.StripeMonthlyPriceID), nullString(tier.StripeYearlyPriceID)); err != nil { + if _, err := a.db.Exec(insertTierQuery, tier.ID, tier.Code, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.SMSLimit, tier.CallLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, nullString(tier.StripeMonthlyPriceID), nullString(tier.StripeYearlyPriceID)); err != nil { return err } return nil @@ -1267,7 +1284,7 @@ func (a *Manager) AddTier(tier *Tier) error { // UpdateTier updates a tier's properties in the database func (a *Manager) UpdateTier(tier *Tier) error { - if _, err := a.db.Exec(updateTierQuery, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, nullString(tier.StripeMonthlyPriceID), nullString(tier.StripeYearlyPriceID), tier.Code); err != nil { + if _, err := a.db.Exec(updateTierQuery, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.SMSLimit, tier.CallLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, nullString(tier.StripeMonthlyPriceID), nullString(tier.StripeYearlyPriceID), tier.Code); err != nil { return err } return nil @@ -1336,11 +1353,11 @@ func (a *Manager) TierByStripePrice(priceID string) (*Tier, error) { func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) { var id, code, name string var stripeMonthlyPriceID, stripeYearlyPriceID sql.NullString - var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit sql.NullInt64 + var messagesLimit, messagesExpiryDuration, emailsLimit, smsLimit, callsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit sql.NullInt64 if !rows.Next() { return nil, ErrTierNotFound } - if err := rows.Scan(&id, &code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil { + if err := rows.Scan(&id, &code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &smsLimit, &callsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil { return nil, err } else if err := rows.Err(); err != nil { return nil, err @@ -1353,6 +1370,8 @@ func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) { MessageLimit: messagesLimit.Int64, MessageExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second, EmailLimit: emailsLimit.Int64, + SMSLimit: smsLimit.Int64, + CallLimit: callsLimit.Int64, ReservationLimit: reservationsLimit.Int64, AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64, AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64, @@ -1495,6 +1514,22 @@ func migrateFrom2(db *sql.DB) error { return tx.Commit() } +func migrateFrom3(db *sql.DB) error { + log.Tag(tag).Info("Migrating user database schema: from 3 to 4") + tx, err := db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + if _, err := tx.Exec(migrate3To4UpdateQueries); err != nil { + return err + } + if _, err := tx.Exec(updateSchemaVersion, 4); err != nil { + return err + } + return tx.Commit() +} + func nullString(s string) sql.NullString { if s == "" { return sql.NullString{} diff --git a/user/types.go b/user/types.go index 2486f110..6340229b 100644 --- a/user/types.go +++ b/user/types.go @@ -86,6 +86,8 @@ type Tier struct { MessageLimit int64 // Daily message limit MessageExpiryDuration time.Duration // Cache duration for messages EmailLimit int64 // Daily email limit + SMSLimit int64 // Daily SMS limit + CallLimit int64 // Daily phone call limit ReservationLimit int64 // Number of topic reservations allowed by user AttachmentFileSizeLimit int64 // Max file size per file (bytes) AttachmentTotalSizeLimit int64 // Total file size for all files of this user (bytes) @@ -131,6 +133,8 @@ type NotificationPrefs struct { type Stats struct { Messages int64 Emails int64 + SMS int64 + Calls int64 } // Billing is a struct holding a user's billing information