WIP: Stripe integration

pull/584/head
binwiederhier 2023-01-14 06:43:44 -05:00
parent 7007c0a0bd
commit 01fd4754f9
20 changed files with 557 additions and 43 deletions

View File

@ -103,7 +103,7 @@ func changeAccess(c *cli.Context, manager *user.Manager, username string, topic
read := util.Contains([]string{"read-write", "rw", "read-only", "read", "ro"}, perms)
write := util.Contains([]string{"read-write", "rw", "write-only", "write", "wo"}, perms)
u, err := manager.User(username)
if err == user.ErrNotFound {
if err == user.ErrUserNotFound {
return fmt.Errorf("user %s does not exist", username)
} else if u.Role == user.RoleAdmin {
return fmt.Errorf("user %s is an admin user, access control entries have no effect", username)
@ -173,7 +173,7 @@ func showAllAccess(c *cli.Context, manager *user.Manager) error {
func showUserAccess(c *cli.Context, manager *user.Manager, username string) error {
users, err := manager.User(username)
if err == user.ErrNotFound {
if err == user.ErrUserNotFound {
return fmt.Errorf("user %s does not exist", username)
} else if err != nil {
return err

View File

@ -5,6 +5,7 @@ package cmd
import (
"errors"
"fmt"
"github.com/stripe/stripe-go/v74"
"heckel.io/ntfy/user"
"io/fs"
"math"
@ -61,7 +62,6 @@ var flagsServe = append(
altsrc.NewBoolFlag(&cli.BoolFlag{Name: "enable-signup", Aliases: []string{"enable_signup"}, EnvVars: []string{"NTFY_ENABLE_SIGNUP"}, Value: false, Usage: "allows users to sign up via the web app, or API"}),
altsrc.NewBoolFlag(&cli.BoolFlag{Name: "enable-login", Aliases: []string{"enable_login"}, EnvVars: []string{"NTFY_ENABLE_LOGIN"}, Value: false, Usage: "allows users to log in via the web app, or API"}),
altsrc.NewBoolFlag(&cli.BoolFlag{Name: "enable-reservations", Aliases: []string{"enable_reservations"}, EnvVars: []string{"NTFY_ENABLE_RESERVATIONS"}, Value: false, Usage: "allows users to reserve topics (if their tier allows it)"}),
altsrc.NewBoolFlag(&cli.BoolFlag{Name: "enable-payments", Aliases: []string{"enable_payments"}, EnvVars: []string{"NTFY_ENABLE_PAYMENTS"}, Value: false, Usage: "enables payments integration [preliminary option, may change]"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "upstream-base-url", Aliases: []string{"upstream_base_url"}, EnvVars: []string{"NTFY_UPSTREAM_BASE_URL"}, Value: "", Usage: "forward poll request to an upstream server, this is needed for iOS push notifications for self-hosted servers"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "smtp-sender-addr", Aliases: []string{"smtp_sender_addr"}, EnvVars: []string{"NTFY_SMTP_SENDER_ADDR"}, Usage: "SMTP server address (host:port) for outgoing emails"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "smtp-sender-user", Aliases: []string{"smtp_sender_user"}, EnvVars: []string{"NTFY_SMTP_SENDER_USER"}, Usage: "SMTP user (if e-mail sending is enabled)"}),
@ -80,6 +80,8 @@ var flagsServe = append(
altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-email-limit-burst", Aliases: []string{"visitor_email_limit_burst"}, EnvVars: []string{"NTFY_VISITOR_EMAIL_LIMIT_BURST"}, Value: server.DefaultVisitorEmailLimitBurst, Usage: "initial limit of e-mails per visitor"}),
altsrc.NewDurationFlag(&cli.DurationFlag{Name: "visitor-email-limit-replenish", Aliases: []string{"visitor_email_limit_replenish"}, EnvVars: []string{"NTFY_VISITOR_EMAIL_LIMIT_REPLENISH"}, Value: server.DefaultVisitorEmailLimitReplenish, Usage: "interval at which burst limit is replenished (one per x)"}),
altsrc.NewBoolFlag(&cli.BoolFlag{Name: "behind-proxy", Aliases: []string{"behind_proxy", "P"}, EnvVars: []string{"NTFY_BEHIND_PROXY"}, Value: false, Usage: "if set, use X-Forwarded-For header to determine visitor IP address (for rate limiting)"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "stripe-key", Aliases: []string{"stripe_key"}, EnvVars: []string{"NTFY_STRIPE_KEY"}, Value: "", Usage: "xxxxxxxxxxxxx"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "stripe-webhook-key", Aliases: []string{"stripe_webhook_key"}, EnvVars: []string{"NTFY_STRIPE_WEBHOOK_KEY"}, Value: "", Usage: "xxxxxxxxxxxx"}),
)
var cmdServe = &cli.Command{
@ -132,7 +134,6 @@ func execServe(c *cli.Context) error {
webRoot := c.String("web-root")
enableSignup := c.Bool("enable-signup")
enableLogin := c.Bool("enable-login")
enablePayments := c.Bool("enable-payments")
enableReservations := c.Bool("enable-reservations")
upstreamBaseURL := c.String("upstream-base-url")
smtpSenderAddr := c.String("smtp-sender-addr")
@ -152,6 +153,8 @@ func execServe(c *cli.Context) error {
visitorEmailLimitBurst := c.Int("visitor-email-limit-burst")
visitorEmailLimitReplenish := c.Duration("visitor-email-limit-replenish")
behindProxy := c.Bool("behind-proxy")
stripeKey := c.String("stripe-key")
stripeWebhookKey := c.String("stripe-webhook-key")
// Check values
if firebaseKeyFile != "" && !util.FileExists(firebaseKeyFile) {
@ -188,14 +191,17 @@ func execServe(c *cli.Context) error {
return errors.New("if upstream-base-url is set, base-url must also be set")
} else if upstreamBaseURL != "" && baseURL != "" && baseURL == upstreamBaseURL {
return errors.New("base-url and upstream-base-url cannot be identical, you'll likely want to set upstream-base-url to https://ntfy.sh, see https://ntfy.sh/docs/config/#ios-instant-notifications")
} else if authFile == "" && (enableSignup || enableLogin || enableReservations || enablePayments) {
return errors.New("cannot set enable-signup, enable-login, enable-reserve-topics, or enable-payments if auth-file is not set")
} else if authFile == "" && (enableSignup || enableLogin || enableReservations || stripeKey != "") {
return errors.New("cannot set enable-signup, enable-login, enable-reserve-topics, or stripe-key if auth-file is not set")
} else if enableSignup && !enableLogin {
return errors.New("cannot set enable-signup without also setting enable-login")
} else if stripeKey != "" && (stripeWebhookKey == "" || baseURL == "") {
return errors.New("if stripe-key is set, stripe-webhook-key and base-url must also be set")
}
webRootIsApp := webRoot == "app"
enableWeb := webRoot != "disable"
enablePayments := stripeKey != ""
// Default auth permissions
authDefault, err := user.ParsePermission(authDefaultAccess)
@ -239,6 +245,11 @@ func execServe(c *cli.Context) error {
visitorRequestLimitExemptIPs = append(visitorRequestLimitExemptIPs, ips...)
}
// Stripe things
if stripeKey != "" {
stripe.Key = stripeKey
}
// Run server
conf := server.NewConfig()
conf.BaseURL = baseURL
@ -282,6 +293,8 @@ func execServe(c *cli.Context) error {
conf.VisitorEmailLimitBurst = visitorEmailLimitBurst
conf.VisitorEmailLimitReplenish = visitorEmailLimitReplenish
conf.BehindProxy = behindProxy
conf.StripeKey = stripeKey
conf.StripeWebhookKey = stripeWebhookKey
conf.EnableWeb = enableWeb
conf.EnableSignup = enableSignup
conf.EnableLogin = enableLogin

View File

@ -215,7 +215,7 @@ func execUserDel(c *cli.Context) error {
if err != nil {
return err
}
if _, err := manager.User(username); err == user.ErrNotFound {
if _, err := manager.User(username); err == user.ErrUserNotFound {
return fmt.Errorf("user %s does not exist", username)
}
if err := manager.RemoveUser(username); err != nil {
@ -237,7 +237,7 @@ func execUserChangePass(c *cli.Context) error {
if err != nil {
return err
}
if _, err := manager.User(username); err == user.ErrNotFound {
if _, err := manager.User(username); err == user.ErrUserNotFound {
return fmt.Errorf("user %s does not exist", username)
}
if password == "" {
@ -265,7 +265,7 @@ func execUserChangeRole(c *cli.Context) error {
if err != nil {
return err
}
if _, err := manager.User(username); err == user.ErrNotFound {
if _, err := manager.User(username); err == user.ErrUserNotFound {
return fmt.Errorf("user %s does not exist", username)
}
if err := manager.ChangeRole(username, role); err != nil {
@ -289,7 +289,7 @@ func execUserChangeTier(c *cli.Context) error {
if err != nil {
return err
}
if _, err := manager.User(username); err == user.ErrNotFound {
if _, err := manager.User(username); err == user.ErrUserNotFound {
return fmt.Errorf("user %s does not exist", username)
}
if tier == tierReset {

4
go.mod
View File

@ -46,6 +46,10 @@ require (
github.com/googleapis/gax-go/v2 v2.7.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/russross/blackfriday/v2 v2.1.0 // indirect
github.com/stripe/stripe-go/v74 v74.5.0 // indirect
github.com/tidwall/gjson v1.14.4 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect
go.opencensus.io v0.24.0 // indirect
golang.org/x/net v0.4.0 // indirect

12
go.sum
View File

@ -95,10 +95,20 @@ github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQD
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stripe/stripe-go/v74 v74.5.0 h1:YyqTvVQdS34KYGCfVB87EMn9eDV3FCFkSwfdOQhiVL4=
github.com/stripe/stripe-go/v74 v74.5.0/go.mod h1:5PoXNp30AJ3tGq57ZcFuaMylzNi8KpwlrYAFmO1fHZw=
github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM=
github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/urfave/cli/v2 v2.23.7 h1:YHDQ46s3VghFHFf1DdF+Sh7H4RqhcM+t0TmZRJx4oJY=
github.com/urfave/cli/v2 v2.23.7/go.mod h1:GHupkWPMM0M/sj1a2b4wUrWBPzazNrIjouW6fmdJLxc=
github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 h1:bAn7/zixMGCfxrRTfdpNzjtPYqr8smhKouy9mxVdGPU=
@ -119,6 +129,7 @@ golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73r
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.0.0-20220708220712-1185a9018129/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
@ -135,6 +146,7 @@ golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

View File

@ -110,6 +110,8 @@ type Config struct {
VisitorAccountCreateLimitReplenish time.Duration
VisitorStatsResetTime time.Time // Time of the day at which to reset visitor stats
BehindProxy bool
StripeKey string
StripeWebhookKey string
EnableWeb bool
EnableSignup bool // Enable creation of accounts via API and UI
EnableLogin bool

View File

@ -58,6 +58,8 @@ var (
errHTTPBadRequestJSONInvalid = &errHTTP{40024, http.StatusBadRequest, "invalid request: request body must be valid JSON", ""}
errHTTPBadRequestPermissionInvalid = &errHTTP{40025, http.StatusBadRequest, "invalid request: incorrect permission string", ""}
errHTTPBadRequestMakesNoSenseForAdmin = &errHTTP{40026, http.StatusBadRequest, "invalid request: this makes no sense for admins", ""}
errHTTPBadRequestNotAPaidUser = &errHTTP{40027, http.StatusBadRequest, "invalid request: not a paid user", ""}
errHTTPBadRequestInvalidStripeRequest = &errHTTP{40028, http.StatusBadRequest, "invalid request: not a valid Stripe request", ""}
errHTTPNotFound = &errHTTP{40401, http.StatusNotFound, "page not found", ""}
errHTTPUnauthorized = &errHTTP{40101, http.StatusUnauthorized, "unauthorized", "https://ntfy.sh/docs/publish/#authentication"}
errHTTPForbidden = &errHTTP{40301, http.StatusForbidden, "forbidden", "https://ntfy.sh/docs/publish/#authentication"}

View File

@ -36,6 +36,10 @@ import (
/*
TODO
payments:
- handle overdue payment (-> downgrade after 7 days)
- delete stripe subscription when acocunt is deleted
Limits & rate limiting:
users without tier: should the stats be persisted? are they meaningful?
-> test that the visitor is based on the IP address!
@ -43,6 +47,7 @@ import (
update last_seen when API is accessed
Make sure account endpoints make sense for admins
triggerChange after publishing a message
UI:
- flicker of upgrade banner
- JS constants
@ -100,6 +105,11 @@ var (
accountSettingsPath = "/v1/account/settings"
accountSubscriptionPath = "/v1/account/subscription"
accountReservationPath = "/v1/account/reservation"
accountBillingPortalPath = "/v1/account/billing/portal"
accountBillingWebhookPath = "/v1/account/billing/webhook"
accountCheckoutPath = "/v1/account/checkout"
accountCheckoutSuccessTemplate = "/v1/account/checkout/success/{CHECKOUT_SESSION_ID}"
accountCheckoutSuccessRegex = regexp.MustCompile(`/v1/account/checkout/success/(.+)$`)
accountReservationSingleRegex = regexp.MustCompile(`/v1/account/reservation/([-_A-Za-z0-9]{1,64})$`)
accountSubscriptionSingleRegex = regexp.MustCompile(`^/v1/account/subscription/([-_A-Za-z0-9]{16})$`)
matrixPushPath = "/_matrix/push/v1/notify"
@ -362,6 +372,14 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit
return s.ensureUser(s.handleAccountReservationAdd)(w, r, v)
} else if r.Method == http.MethodDelete && accountReservationSingleRegex.MatchString(r.URL.Path) {
return s.ensureUser(s.handleAccountReservationDelete)(w, r, v)
} else if r.Method == http.MethodPost && r.URL.Path == accountCheckoutPath {
return s.ensureUser(s.handleAccountCheckoutSessionCreate)(w, r, v)
} else if r.Method == http.MethodGet && accountCheckoutSuccessRegex.MatchString(r.URL.Path) {
return s.ensureUserManager(s.handleAccountCheckoutSessionSuccessGet)(w, r, v) // No user context!
} else if r.Method == http.MethodPost && r.URL.Path == accountBillingPortalPath {
return s.ensureUser(s.handleAccountBillingPortalSessionCreate)(w, r, v)
} else if r.Method == http.MethodPost && r.URL.Path == accountBillingWebhookPath {
return s.ensureUserManager(s.handleAccountBillingWebhookTrigger)(w, r, v)
} else if r.Method == http.MethodGet && r.URL.Path == matrixPushPath {
return s.handleMatrixDiscovery(w)
} else if r.Method == http.MethodGet && staticRegex.MatchString(r.URL.Path) {

View File

@ -2,6 +2,14 @@ package server
import (
"encoding/json"
"errors"
"github.com/stripe/stripe-go/v74"
portalsession "github.com/stripe/stripe-go/v74/billingportal/session"
"github.com/stripe/stripe-go/v74/checkout/session"
"github.com/stripe/stripe-go/v74/subscription"
"github.com/stripe/stripe-go/v74/webhook"
"github.com/tidwall/gjson"
"heckel.io/ntfy/log"
"heckel.io/ntfy/user"
"heckel.io/ntfy/util"
"net/http"
@ -9,6 +17,7 @@ import (
const (
jsonBodyBytesLimit = 4096
stripeBodyBytesLimit = 16384
subscriptionIDLength = 16
createdByAPI = "api"
)
@ -386,3 +395,226 @@ func (s *Server) handleAccountReservationDelete(w http.ResponseWriter, r *http.R
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
return nil
}
func (s *Server) handleAccountCheckoutSessionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error {
req, err := readJSONWithLimit[apiAccountTierChangeRequest](r.Body, jsonBodyBytesLimit)
if err != nil {
return err
}
tier, err := s.userManager.Tier(req.Tier)
if err != nil {
return err
}
if tier.StripePriceID == "" {
log.Info("Checkout: Downgrading to no tier")
return errors.New("not a paid tier")
} else if v.user.Billing != nil && v.user.Billing.StripeSubscriptionID != "" {
log.Info("Checkout: Changing tier and subscription to %s", tier.Code)
// Upgrade/downgrade tier
sub, err := subscription.Get(v.user.Billing.StripeSubscriptionID, nil)
if err != nil {
return err
}
params := &stripe.SubscriptionParams{
CancelAtPeriodEnd: stripe.Bool(false),
ProrationBehavior: stripe.String(string(stripe.SubscriptionSchedulePhaseProrationBehaviorCreateProrations)),
Items: []*stripe.SubscriptionItemsParams{
{
ID: stripe.String(sub.Items.Data[0].ID),
Price: stripe.String(tier.StripePriceID),
},
},
}
_, err = subscription.Update(sub.ID, params)
if err != nil {
return err
}
response := &apiAccountCheckoutResponse{}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
if err := json.NewEncoder(w).Encode(response); err != nil {
return err
}
return nil
} else {
// Checkout flow
log.Info("Checkout: No existing subscription, creating checkout flow")
}
successURL := s.config.BaseURL + accountCheckoutSuccessTemplate
var stripeCustomerID *string
if v.user.Billing != nil {
stripeCustomerID = &v.user.Billing.StripeCustomerID
}
params := &stripe.CheckoutSessionParams{
ClientReferenceID: &v.user.Name, // FIXME Should be user ID
Customer: stripeCustomerID,
SuccessURL: &successURL,
Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)),
LineItems: []*stripe.CheckoutSessionLineItemParams{
{
Price: stripe.String(tier.StripePriceID),
Quantity: stripe.Int64(1),
},
},
}
sess, err := session.New(params)
if err != nil {
return err
}
response := &apiAccountCheckoutResponse{
RedirectURL: sess.URL,
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
if err := json.NewEncoder(w).Encode(response); err != nil {
return err
}
return nil
}
func (s *Server) handleAccountCheckoutSessionSuccessGet(w http.ResponseWriter, r *http.Request, v *visitor) error {
// We don't have a v.user in this endpoint, only a userManager!
matches := accountCheckoutSuccessRegex.FindStringSubmatch(r.URL.Path)
if len(matches) != 2 {
return errHTTPInternalErrorInvalidPath
}
sessionID := matches[1]
// FIXME how do I rate limit this?
sess, err := session.Get(sessionID, nil)
if err != nil {
log.Warn("Stripe: %s", err)
return errHTTPBadRequestInvalidStripeRequest
} else if sess.Customer == nil || sess.Subscription == nil || sess.ClientReferenceID == "" {
log.Warn("Stripe: Unexpected session, customer or subscription not found")
return errHTTPBadRequestInvalidStripeRequest
}
sub, err := subscription.Get(sess.Subscription.ID, nil)
if err != nil {
return err
} else if sub.Items == nil || len(sub.Items.Data) != 1 || sub.Items.Data[0].Price == nil {
log.Error("Stripe: Unexpected subscription, expected exactly one line item")
return errHTTPBadRequestInvalidStripeRequest
}
priceID := sub.Items.Data[0].Price.ID
tier, err := s.userManager.TierByStripePrice(priceID)
if err != nil {
return err
}
u, err := s.userManager.User(sess.ClientReferenceID)
if err != nil {
return err
}
if u.Billing == nil {
u.Billing = &user.Billing{}
}
u.Billing.StripeCustomerID = sess.Customer.ID
u.Billing.StripeSubscriptionID = sess.Subscription.ID
if err := s.userManager.ChangeBilling(u); err != nil {
return err
}
if err := s.userManager.ChangeTier(u.Name, tier.Code); err != nil {
return err
}
accountURL := s.config.BaseURL + "/account" // FIXME
http.Redirect(w, r, accountURL, http.StatusSeeOther)
return nil
}
func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error {
if v.user.Billing == nil {
return errHTTPBadRequestNotAPaidUser
}
params := &stripe.BillingPortalSessionParams{
Customer: stripe.String(v.user.Billing.StripeCustomerID),
ReturnURL: stripe.String(s.config.BaseURL),
}
ps, err := portalsession.New(params)
if err != nil {
return err
}
response := &apiAccountBillingPortalRedirectResponse{
RedirectURL: ps.URL,
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
if err := json.NewEncoder(w).Encode(response); err != nil {
return err
}
return nil
}
func (s *Server) handleAccountBillingWebhookTrigger(w http.ResponseWriter, r *http.Request, v *visitor) error {
// We don't have a v.user in this endpoint, only a userManager!
stripeSignature := r.Header.Get("Stripe-Signature")
if stripeSignature == "" {
return errHTTPBadRequestInvalidStripeRequest
}
body, err := util.Peek(r.Body, stripeBodyBytesLimit)
if err != nil {
return err
} else if body.LimitReached {
return errHTTPEntityTooLargeJSONBody
}
event, err := webhook.ConstructEvent(body.PeekedBytes, stripeSignature, s.config.StripeWebhookKey)
if err != nil {
log.Warn("Stripe: invalid request: %s", err.Error())
return errHTTPBadRequestInvalidStripeRequest
} else if event.Data == nil || event.Data.Raw == nil {
log.Warn("Stripe: invalid request, data is nil")
return errHTTPBadRequestInvalidStripeRequest
}
log.Info("Stripe: webhook event %s received", event.Type)
stripeCustomerID := gjson.GetBytes(event.Data.Raw, "customer")
if !stripeCustomerID.Exists() {
return errHTTPBadRequestInvalidStripeRequest
}
switch event.Type {
case "checkout.session.completed":
// Payment is successful and the subscription is created.
// Provision the subscription, save the customer ID.
return s.handleAccountBillingWebhookCheckoutCompleted(stripeCustomerID.String(), event.Data.Raw)
case "customer.subscription.updated":
return s.handleAccountBillingWebhookSubscriptionUpdated(stripeCustomerID.String(), event.Data.Raw)
case "invoice.paid":
// Continue to provision the subscription as payments continue to be made.
// Store the status in your database and check when a user accesses your service.
// This approach helps you avoid hitting rate limits.
return nil // FIXME
case "invoice.payment_failed":
// The payment failed or the customer does not have a valid payment method.
// The subscription becomes past_due. Notify your customer and send them to the
// customer portal to update their payment information.
return nil // FIXME
default:
log.Warn("Stripe: unhandled webhook %s", event.Type)
return nil
}
}
func (s *Server) handleAccountBillingWebhookCheckoutCompleted(stripeCustomerID string, event json.RawMessage) error {
log.Info("Stripe: checkout completed for customer %s", stripeCustomerID)
return nil
}
func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(stripeCustomerID string, event json.RawMessage) error {
status := gjson.GetBytes(event, "status")
priceID := gjson.GetBytes(event, "items.data.0.price.id")
if !status.Exists() || !priceID.Exists() {
return errHTTPBadRequestInvalidStripeRequest
}
log.Info("Stripe: customer %s: subscription updated to %s, with price %s", stripeCustomerID, status, priceID)
u, err := s.userManager.UserByStripeCustomer(stripeCustomerID)
if err != nil {
return err
}
tier, err := s.userManager.TierByStripePrice(priceID.String())
if err != nil {
return err
}
if err := s.userManager.ChangeTier(u.Name, tier.Code); err != nil {
return err
}
return nil
}

View File

@ -295,3 +295,15 @@ type apiConfigResponse struct {
EnableReservations bool `json:"enable_reservations"`
DisallowedTopics []string `json:"disallowed_topics"`
}
type apiAccountTierChangeRequest struct {
Tier string `json:"tier"`
}
type apiAccountCheckoutResponse struct {
RedirectURL string `json:"redirect_url"`
}
type apiAccountBillingPortalRedirectResponse struct {
RedirectURL string `json:"redirect_url"`
}

View File

@ -44,8 +44,11 @@ const (
reservations_limit INT NOT NULL,
attachment_file_size_limit INT NOT NULL,
attachment_total_size_limit INT NOT NULL,
attachment_expiry_duration INT NOT NULL
attachment_expiry_duration INT NOT NULL,
stripe_price_id TEXT
);
CREATE UNIQUE INDEX idx_tier_code ON tier (code);
CREATE UNIQUE INDEX idx_tier_price_id ON tier (stripe_price_id);
CREATE TABLE IF NOT EXISTS user (
id INTEGER PRIMARY KEY AUTOINCREMENT,
tier_id INT,
@ -56,12 +59,16 @@ const (
sync_topic TEXT NOT NULL,
stats_messages INT NOT NULL DEFAULT (0),
stats_emails INT NOT NULL DEFAULT (0),
stripe_customer_id TEXT,
stripe_subscription_id TEXT,
created_by TEXT NOT NULL,
created_at INT NOT NULL,
last_seen INT NOT NULL,
FOREIGN KEY (tier_id) REFERENCES tier (id)
);
CREATE UNIQUE INDEX idx_user ON user (user);
CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
CREATE TABLE IF NOT EXISTS user_access (
user_id INT NOT NULL,
topic TEXT NOT NULL,
@ -93,18 +100,24 @@ const (
`
selectUserByNameQuery = `
SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration
SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id
FROM user u
LEFT JOIN tier p on p.id = u.tier_id
WHERE user = ?
`
selectUserByTokenQuery = `
SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration
SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id
FROM user u
JOIN user_token t on u.id = t.user_id
LEFT JOIN tier p on p.id = u.tier_id
WHERE t.token = ? AND t.expires >= ?
`
selectUserByStripeCustomerIDQuery = `
SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id
FROM user u
LEFT JOIN tier p on p.id = u.tier_id
WHERE u.stripe_customer_id = ?
`
selectTopicPermsQuery = `
SELECT read, write
FROM user_access a
@ -205,8 +218,20 @@ const (
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`
selectTierIDQuery = `SELECT id FROM tier WHERE code = ?`
selectTierByCodeQuery = `
SELECT code, name, paid, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id
FROM tier
WHERE code = ?
`
selectTierByPriceIDQuery = `
SELECT code, name, paid, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id
FROM tier
WHERE stripe_price_id = ?
`
updateUserTierQuery = `UPDATE user SET tier_id = ? WHERE user = ?`
deleteUserTierQuery = `UPDATE user SET tier_id = null WHERE user = ?`
updateBillingQuery = `UPDATE user SET stripe_customer_id = ?, stripe_subscription_id = ? WHERE user = ?`
)
// Schema management queries
@ -543,7 +568,7 @@ func (a *Manager) Users() ([]*User, error) {
return users, nil
}
// User returns the user with the given username if it exists, or ErrNotFound otherwise.
// User returns the user with the given username if it exists, or ErrUserNotFound otherwise.
// You may also pass Everyone to retrieve the anonymous user and its Grant list.
func (a *Manager) User(username string) (*User, error) {
rows, err := a.db.Query(selectUserByNameQuery, username)
@ -553,6 +578,14 @@ func (a *Manager) User(username string) (*User, error) {
return a.readUser(rows)
}
func (a *Manager) UserByStripeCustomer(stripeCustomerID string) (*User, error) {
rows, err := a.db.Query(selectUserByStripeCustomerIDQuery, stripeCustomerID)
if err != nil {
return nil, err
}
return a.readUser(rows)
}
func (a *Manager) userByToken(token string) (*User, error) {
rows, err := a.db.Query(selectUserByTokenQuery, token, time.Now().Unix())
if err != nil {
@ -564,14 +597,14 @@ func (a *Manager) userByToken(token string) (*User, error) {
func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
defer rows.Close()
var username, hash, role, prefs, syncTopic string
var tierCode, tierName sql.NullString
var stripeCustomerID, stripeSubscriptionID, stripePriceID, tierCode, tierName sql.NullString
var paid sql.NullBool
var messages, emails int64
var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration sql.NullInt64
if !rows.Next() {
return nil, ErrNotFound
return nil, ErrUserNotFound
}
if err := rows.Scan(&username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &tierCode, &tierName, &paid, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration); err != nil {
if err := rows.Scan(&username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &tierCode, &tierName, &paid, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil {
return nil, err
} else if err := rows.Err(); err != nil {
return nil, err
@ -590,7 +623,14 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
if err := json.Unmarshal([]byte(prefs), user.Prefs); err != nil {
return nil, err
}
if stripeCustomerID.Valid && stripeSubscriptionID.Valid {
user.Billing = &Billing{
StripeCustomerID: stripeCustomerID.String,
StripeSubscriptionID: stripeSubscriptionID.String,
}
}
if tierCode.Valid {
// See readTier() when this is changed!
user.Tier = &Tier{
Code: tierCode.String,
Name: tierName.String,
@ -602,6 +642,7 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
StripePriceID: stripePriceID.String,
}
}
return user, nil
@ -826,6 +867,59 @@ func (a *Manager) CreateTier(tier *Tier) error {
return nil
}
func (a *Manager) ChangeBilling(user *User) error {
if _, err := a.db.Exec(updateBillingQuery, user.Billing.StripeCustomerID, user.Billing.StripeSubscriptionID, user.Name); err != nil {
return err
}
return nil
}
func (a *Manager) Tier(code string) (*Tier, error) {
rows, err := a.db.Query(selectTierByCodeQuery, code)
if err != nil {
return nil, err
}
return a.readTier(rows)
}
func (a *Manager) TierByStripePrice(priceID string) (*Tier, error) {
rows, err := a.db.Query(selectTierByPriceIDQuery, priceID)
if err != nil {
return nil, err
}
return a.readTier(rows)
}
func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) {
defer rows.Close()
var code, name string
var stripePriceID sql.NullString
var paid bool
var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration sql.NullInt64
if !rows.Next() {
return nil, ErrTierNotFound
}
if err := rows.Scan(&code, &name, &paid, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil {
return nil, err
} else if err := rows.Err(); err != nil {
return nil, err
}
// When changed, note readUser() as well
return &Tier{
Code: code,
Name: name,
Paid: paid,
MessagesLimit: messagesLimit.Int64,
MessagesExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second,
EmailsLimit: emailsLimit.Int64,
ReservationsLimit: reservationsLimit.Int64,
AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
StripePriceID: stripePriceID.String, // May be empty!
}, nil
}
func toSQLWildcard(s string) string {
return strings.ReplaceAll(s, "*", "%")
}

View File

@ -208,7 +208,7 @@ func TestManager_UserManagement(t *testing.T) {
// Remove user
require.Nil(t, a.RemoveUser("ben"))
_, err = a.User("ben")
require.Equal(t, ErrNotFound, err)
require.Equal(t, ErrUserNotFound, err)
users, err = a.Users()
require.Nil(t, err)

View File

@ -16,6 +16,7 @@ type User struct {
Prefs *Prefs
Tier *Tier
Stats *Stats
Billing *Billing
SyncTopic string
Created time.Time
LastSeen time.Time
@ -58,6 +59,7 @@ type Tier struct {
AttachmentFileSizeLimit int64
AttachmentTotalSizeLimit int64
AttachmentExpiryDuration time.Duration
StripePriceID string
}
// Subscription represents a user's topic subscription
@ -81,6 +83,12 @@ type Stats struct {
Emails int64
}
// Billing is a struct holding a user's billing information
type Billing struct {
StripeCustomerID string
StripeSubscriptionID string
}
// Grant is a struct that represents an access control entry to a topic by a user
type Grant struct {
TopicPattern string // May include wildcard (*)
@ -212,5 +220,6 @@ var (
ErrUnauthenticated = errors.New("unauthenticated")
ErrUnauthorized = errors.New("unauthorized")
ErrInvalidArgument = errors.New("invalid argument")
ErrNotFound = errors.New("not found")
ErrUserNotFound = errors.New("user not found")
ErrTierNotFound = errors.New("tier not found")
)

View File

@ -8,7 +8,7 @@ import {
accountTokenUrl,
accountUrl, maybeWithAuth, topicUrl,
withBasicAuth,
withBearerAuth
withBearerAuth, accountCheckoutUrl, accountBillingPortalUrl
} from "./utils";
import session from "./Session";
import subscriptionManager from "./SubscriptionManager";
@ -228,7 +228,7 @@ class AccountApi {
this.triggerChange(); // Dangle!
}
async upsertAccess(topic, everyone) {
async upsertReservation(topic, everyone) {
const url = accountReservationUrl(config.base_url);
console.log(`[AccountApi] Upserting user access to topic ${topic}, everyone=${everyone}`);
const response = await fetch(url, {
@ -249,7 +249,7 @@ class AccountApi {
this.triggerChange(); // Dangle!
}
async deleteAccess(topic) {
async deleteReservation(topic) {
const url = accountReservationSingleUrl(config.base_url, topic);
console.log(`[AccountApi] Removing topic reservation ${url}`);
const response = await fetch(url, {
@ -264,6 +264,39 @@ class AccountApi {
this.triggerChange(); // Dangle!
}
async createCheckoutSession(tier) {
const url = accountCheckoutUrl(config.base_url);
console.log(`[AccountApi] Creating checkout session`);
const response = await fetch(url, {
method: "POST",
headers: withBearerAuth({}, session.token()),
body: JSON.stringify({
tier: tier
})
});
if (response.status === 401 || response.status === 403) {
throw new UnauthorizedError();
} else if (response.status !== 200) {
throw new Error(`Unexpected server response ${response.status}`);
}
return await response.json();
}
async createBillingPortalSession() {
const url = accountBillingPortalUrl(config.base_url);
console.log(`[AccountApi] Creating billing portal session`);
const response = await fetch(url, {
method: "POST",
headers: withBearerAuth({}, session.token())
});
if (response.status === 401 || response.status === 403) {
throw new UnauthorizedError();
} else if (response.status !== 200) {
throw new Error(`Unexpected server response ${response.status}`);
}
return await response.json();
}
async sync() {
try {
if (!session.token()) {

View File

@ -26,6 +26,8 @@ export const accountSubscriptionUrl = (baseUrl) => `${baseUrl}/v1/account/subscr
export const accountSubscriptionSingleUrl = (baseUrl, id) => `${baseUrl}/v1/account/subscription/${id}`;
export const accountReservationUrl = (baseUrl) => `${baseUrl}/v1/account/reservation`;
export const accountReservationSingleUrl = (baseUrl, topic) => `${baseUrl}/v1/account/reservation/${topic}`;
export const accountCheckoutUrl = (baseUrl) => `${baseUrl}/v1/account/checkout`;
export const accountBillingPortalUrl = (baseUrl) => `${baseUrl}/v1/account/billing/portal`;
export const shortUrl = (url) => url.replaceAll(/https?:\/\//g, "");
export const expandUrl = (url) => [`https://${url}`, `http://${url}`];
export const expandSecureUrl = (url) => `https://${url}`;

View File

@ -171,10 +171,28 @@ const Stats = () => {
const { t } = useTranslation();
const { account } = useContext(AccountContext);
const [upgradeDialogOpen, setUpgradeDialogOpen] = useState(false);
if (!account) {
return <></>;
}
const normalize = (value, max) => Math.min(value / max * 100, 100);
const normalize = (value, max) => {
return Math.min(value / max * 100, 100);
};
const handleManageBilling = async () => {
try {
const response = await accountApi.createBillingPortalSession();
window.location.href = response.redirect_url;
} catch (e) {
console.log(`[Account] Error changing password`, e);
if ((e instanceof UnauthorizedError)) {
session.resetAndRedirect(routes.login);
}
// TODO show error
}
};
return (
<Card sx={{p: 3}} aria-label={t("account_usage_title")}>
<Typography variant="h5" sx={{marginBottom: 2}}>
@ -201,12 +219,20 @@ const Stats = () => {
>{t("account_usage_tier_upgrade_button")}</Button>
}
{config.enable_payments && account.role === "user" && account.tier?.paid &&
<>
<Button
variant="outlined"
size="small"
onClick={() => setUpgradeDialogOpen(true)}
sx={{ml: 1}}
>{t("account_usage_tier_change_button")}</Button>
<Button
variant="outlined"
size="small"
onClick={handleManageBilling}
sx={{ml: 1}}
>Manage billing</Button>
</>
}
<UpgradeDialog
open={upgradeDialogOpen}

View File

@ -501,7 +501,7 @@ const Reservations = () => {
const handleDialogSubmit = async (reservation) => {
setDialogOpen(false);
try {
await accountApi.upsertAccess(reservation.topic, reservation.everyone);
await accountApi.upsertReservation(reservation.topic, reservation.everyone);
await accountApi.sync();
console.debug(`[Preferences] Added topic reservation`, reservation);
} catch (e) {
@ -557,7 +557,7 @@ const ReservationsTable = (props) => {
const handleDialogSubmit = async (reservation) => {
setDialogOpen(false);
try {
await accountApi.upsertAccess(reservation.topic, reservation.everyone);
await accountApi.upsertReservation(reservation.topic, reservation.everyone);
await accountApi.sync();
console.debug(`[Preferences] Added topic reservation`, reservation);
} catch (e) {
@ -568,7 +568,7 @@ const ReservationsTable = (props) => {
const handleDeleteClick = async (reservation) => {
try {
await accountApi.deleteAccess(reservation.topic);
await accountApi.deleteReservation(reservation.topic);
await accountApi.sync();
console.debug(`[Preferences] Deleted topic reservation`, reservation);
} catch (e) {

View File

@ -110,7 +110,7 @@ const SubscribePage = (props) => {
if (session.exists() && baseUrl === config.base_url && reserveTopicVisible) {
console.log(`[SubscribeDialog] Reserving topic ${topic} with everyone access ${everyone}`);
try {
await accountApi.upsertAccess(topic, everyone);
await accountApi.upsertReservation(topic, everyone);
// Account sync later after it was added
} catch (e) {
console.log(`[SubscribeDialog] Error reserving topic`, e);

View File

@ -37,9 +37,9 @@ const SubscriptionSettingsDialog = (props) => {
// Reservation
if (reserveTopicVisible) {
await accountApi.upsertAccess(subscription.topic, everyone);
await accountApi.upsertReservation(subscription.topic, everyone);
} else if (!reserveTopicVisible && subscription.reservation) { // Was removed
await accountApi.deleteAccess(subscription.topic);
await accountApi.deleteReservation(subscription.topic);
}
// Sync account

View File

@ -2,28 +2,83 @@ import * as React from 'react';
import Dialog from '@mui/material/Dialog';
import DialogContent from '@mui/material/DialogContent';
import DialogTitle from '@mui/material/DialogTitle';
import {useMediaQuery} from "@mui/material";
import {CardActionArea, CardContent, useMediaQuery} from "@mui/material";
import theme from "./theme";
import DialogFooter from "./DialogFooter";
import Button from "@mui/material/Button";
import accountApi, {TopicReservedError, UnauthorizedError} from "../app/AccountApi";
import session from "../app/Session";
import routes from "./routes";
import {useContext, useState} from "react";
import Card from "@mui/material/Card";
import Typography from "@mui/material/Typography";
import {AccountContext} from "./App";
const UpgradeDialog = (props) => {
const { account } = useContext(AccountContext);
const fullScreen = useMediaQuery(theme.breakpoints.down('sm'));
const [selected, setSelected] = useState(account?.tier?.code || null);
const [errorText, setErrorText] = useState("");
const handleSuccess = async () => {
// TODO
const handleCheckout = async () => {
try {
const response = await accountApi.createCheckoutSession(selected);
if (response.redirect_url) {
window.location.href = response.redirect_url;
} else {
await accountApi.sync();
}
} catch (e) {
console.log(`[UpgradeDialog] Error creating checkout session`, e);
if ((e instanceof UnauthorizedError)) {
session.resetAndRedirect(routes.login);
}
// FIXME show error
}
}
return (
<Dialog open={props.open} onClose={props.onCancel} fullScreen={fullScreen}>
<Dialog open={props.open} onClose={props.onCancel} maxWidth="md" fullScreen={fullScreen}>
<DialogTitle>Upgrade to Pro</DialogTitle>
<DialogContent>
Content
<div style={{
display: "flex",
flexDirection: "row"
}}>
<TierCard code={null} name={"Free"} selected={selected === null} onClick={() => setSelected(null)}/>
<TierCard code="starter" name={"Starter"} selected={selected === "starter"} onClick={() => setSelected("starter")}/>
<TierCard code="pro" name={"Pro"} selected={selected === "pro"} onClick={() => setSelected("pro")}/>
<TierCard code="business" name={"Business"} selected={selected === "business"} onClick={() => setSelected("business")}/>
</div>
</DialogContent>
<DialogFooter>
Footer
<DialogFooter status={errorText}>
<Button onClick={handleCheckout}>Checkout</Button>
</DialogFooter>
</Dialog>
);
};
const TierCard = (props) => {
const cardStyle = (props.selected) ? {
border: "1px solid red",
} : {};
return (
<Card sx={{ m: 1, maxWidth: 345 }}>
<CardActionArea>
<CardContent sx={{...cardStyle}} onClick={props.onClick}>
<Typography gutterBottom variant="h5" component="div">
{props.name}
</Typography>
<Typography variant="body2" color="text.secondary">
Lizards are a widespread group of squamate reptiles, with over 6,000
species, ranging across all continents except Antarctica
</Typography>
</CardContent>
</CardActionArea>
</Card>
);
}
export default UpgradeDialog;