ntfy/server/server_payments_test.go

338 lines
11 KiB
Go

package server
import (
"encoding/json"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/stripe/stripe-go/v74"
"heckel.io/ntfy/user"
"heckel.io/ntfy/util"
"io"
"path/filepath"
"strings"
"testing"
"time"
)
func TestPayments_SubscriptionCreate_NotAStripeCustomer_Success(t *testing.T) {
	// A user without a Stripe customer ID subscribes to the "pro" tier: the server
	// must create a checkout session and return its URL as the redirect target.
	stripeMock := &testStripeAPI{}
	defer stripeMock.AssertExpectations(t)

	conf := newTestConfigWithAuthFile(t)
	conf.StripeSecretKey = "secret key"
	conf.StripeWebhookKey = "webhook key"
	s := newTestServer(t, conf)
	s.stripe = stripeMock

	// The only Stripe call this test sets up: creating the checkout session
	checkout := &stripe.CheckoutSession{URL: "https://billing.stripe.com/abc/def"}
	stripeMock.On("NewCheckoutSession", mock.Anything).Return(checkout, nil)

	// Tier that can be purchased, and a plain user without billing details
	tier := &user.Tier{Code: "pro", StripePriceID: "price_123"}
	require.Nil(t, s.userManager.CreateTier(tier))
	require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))

	// Request the subscription and check the redirect URL handed back
	auth := map[string]string{"Authorization": util.BasicAuth("phil", "phil")}
	rr := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro"}`, auth)
	require.Equal(t, 200, rr.Code)
	resp, err := util.UnmarshalJSON[apiAccountBillingSubscriptionCreateResponse](io.NopCloser(rr.Body))
	require.Nil(t, err)
	require.Equal(t, "https://billing.stripe.com/abc/def", resp.RedirectURL)
}
func TestPayments_SubscriptionCreate_StripeCustomer_Success(t *testing.T) {
	// A user who already has a Stripe customer ID (with an empty subscription list)
	// subscribes to "pro": the server must look up the customer and create a
	// checkout session, returning the session URL as the redirect target.
	stripeMock := &testStripeAPI{}
	defer stripeMock.AssertExpectations(t)

	conf := newTestConfigWithAuthFile(t)
	conf.StripeSecretKey = "secret key"
	conf.StripeWebhookKey = "webhook key"
	s := newTestServer(t, conf)
	s.stripe = stripeMock

	// Both calls below are required by AssertExpectations
	customer := &stripe.Customer{Subscriptions: &stripe.SubscriptionList{}}
	checkout := &stripe.CheckoutSession{URL: "https://billing.stripe.com/abc/def"}
	stripeMock.On("GetCustomer", "acct_123").Return(customer, nil)
	stripeMock.On("NewCheckoutSession", mock.Anything).Return(checkout, nil)

	// Tier, user, and the user's existing Stripe customer ID
	require.Nil(t, s.userManager.CreateTier(&user.Tier{Code: "pro", StripePriceID: "price_123"}))
	require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
	phil, err := s.userManager.User("phil")
	require.Nil(t, err)
	require.Nil(t, s.userManager.ChangeBilling(phil.Name, &user.Billing{StripeCustomerID: "acct_123"}))

	// Request the subscription and check the redirect URL handed back
	auth := map[string]string{"Authorization": util.BasicAuth("phil", "phil")}
	rr := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro"}`, auth)
	require.Equal(t, 200, rr.Code)
	resp, err := util.UnmarshalJSON[apiAccountBillingSubscriptionCreateResponse](io.NopCloser(rr.Body))
	require.Nil(t, err)
	require.Equal(t, "https://billing.stripe.com/abc/def", resp.RedirectURL)
}
func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) {
	// Deleting an account that has a Stripe subscription must cancel that
	// subscription (enforced by the CancelSubscription mock expectation) and the
	// account must no longer be usable afterwards.
	stripeMock := &testStripeAPI{}
	defer stripeMock.AssertExpectations(t)
	c := newTestConfigWithAuthFile(t)
	c.EnableSignup = true
	c.StripeSecretKey = "secret key"
	c.StripeWebhookKey = "webhook key"
	s := newTestServer(t, c)
	s.stripe = stripeMock

	// Define how the mock should react
	stripeMock.
		On("CancelSubscription", "sub_123").
		Return(&stripe.Subscription{}, nil)

	// Create tier and user
	require.Nil(t, s.userManager.CreateTier(&user.Tier{
		Code:          "pro",
		StripePriceID: "price_123",
	}))
	require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
	u, err := s.userManager.User("phil")
	require.Nil(t, err)
	billing := &user.Billing{
		StripeCustomerID:     "acct_123",
		StripeSubscriptionID: "sub_123",
	}
	require.Nil(t, s.userManager.ChangeBilling(u.Name, billing))

	// Delete account (succeeding here also proves the credentials are valid)
	rr := request(t, s, "DELETE", "/v1/account", "", map[string]string{
		"Authorization": util.BasicAuth("phil", "phil"),
	})
	require.Equal(t, 200, rr.Code)

	// Authenticate again with the user's real credentials. Only a deleted account
	// yields 401 here; a wrong password would return 401 regardless and prove nothing.
	rr = request(t, s, "GET", "/v1/account", "", map[string]string{
		"Authorization": util.BasicAuth("phil", "phil"),
	})
	require.Equal(t, 401, rr.Code)
}
func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(t *testing.T) {
// This tests incoming webhooks from Stripe to update a subscription:
// - All Stripe columns are updated in the user table
// - When downgrading, excess reservations are deleted, including messages and attachments in
// the corresponding topics
// Setup: the user starts on "pro" (price_1111, 3 reservations allowed) with a past_due
// subscription. The webhook event (subscriptionUpdatedEventJSON) reports status "active"
// and price_1234, which belongs to "starter" (1 reservation allowed).
stripeMock := &testStripeAPI{}
defer stripeMock.AssertExpectations(t)
c := newTestConfigWithAuthFile(t)
c.StripeSecretKey = "secret key"
c.StripeWebhookKey = "webhook key"
s := newTestServer(t, c)
s.stripe = stripeMock
// Define how the mock should react: any payload with the expected signature header and
// webhook key yields the canned "customer.subscription.updated" event
stripeMock.
On("ConstructWebhookEvent", mock.Anything, "stripe signature", "webhook key").
Return(jsonToStripeEvent(t, subscriptionUpdatedEventJSON), nil)
// Create the "starter" and "pro" tiers, then a "pro" user with 2 reservations
require.Nil(t, s.userManager.CreateTier(&user.Tier{
Code: "starter",
StripePriceID: "price_1234", // ! same price as in the webhook event, so the user is moved to this tier
ReservationsLimit: 1, // ! only one of the user's two reservations fits after the downgrade
MessagesLimit: 100,
MessagesExpiryDuration: time.Hour,
AttachmentExpiryDuration: time.Hour,
AttachmentFileSizeLimit: 1000000,
AttachmentTotalSizeLimit: 1000000,
}))
require.Nil(t, s.userManager.CreateTier(&user.Tier{
Code: "pro",
StripePriceID: "price_1111", // ! the user's tier before the webhook arrives
ReservationsLimit: 3, // ! both reservations fit in this tier
MessagesLimit: 200,
MessagesExpiryDuration: time.Hour,
AttachmentExpiryDuration: time.Hour,
AttachmentFileSizeLimit: 1000000,
AttachmentTotalSizeLimit: 1000000,
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
require.Nil(t, s.userManager.ReserveAccess("phil", "atopic", user.PermissionDenyAll))
require.Nil(t, s.userManager.ReserveAccess("phil", "ztopic", user.PermissionDenyAll))
// Add billing details (status past_due and old timestamps, all expected to be overwritten)
u, err := s.userManager.User("phil")
require.Nil(t, err)
billing := &user.Billing{
StripeCustomerID: "acct_5555",
StripeSubscriptionID: "sub_1234",
StripeSubscriptionStatus: stripe.SubscriptionStatusPastDue,
StripeSubscriptionPaidUntil: time.Unix(123, 0),
StripeSubscriptionCancelAt: time.Unix(456, 0),
}
require.Nil(t, s.userManager.ChangeBilling(u.Name, billing))
// Add some messages to "atopic" and "ztopic", everything in "ztopic" will be deleted.
// The 5000-byte bodies are stored as attachments (checked via FileExists below).
rr := request(t, s, "PUT", "/atopic", "some aaa message", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
rr = request(t, s, "PUT", "/atopic", strings.Repeat("a", 5000), map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
a2 := toMessage(t, rr.Body.String())
require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, a2.ID))
rr = request(t, s, "PUT", "/ztopic", "some zzz message", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
rr = request(t, s, "PUT", "/ztopic", strings.Repeat("z", 5000), map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
z2 := toMessage(t, rr.Body.String())
require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, z2.ID))
// Call the webhook. The "dummy" body is never parsed as JSON: the mocked
// ConstructWebhookEvent matches any payload and returns the canned event.
rr = request(t, s, "POST", "/v1/account/billing/webhook", "dummy", map[string]string{
"Stripe-Signature": "stripe signature",
})
require.Equal(t, 200, rr.Code)
// Verify that database columns were updated
u, err = s.userManager.User("phil")
require.Nil(t, err)
require.Equal(t, "starter", u.Tier.Code) // Not "pro"
require.Equal(t, "acct_5555", u.Billing.StripeCustomerID)
require.Equal(t, "sub_1234", u.Billing.StripeSubscriptionID)
require.Equal(t, stripe.SubscriptionStatusActive, u.Billing.StripeSubscriptionStatus) // Not "past_due"
require.Equal(t, int64(1674268231), u.Billing.StripeSubscriptionPaidUntil.Unix()) // Updated
require.Equal(t, int64(1674299999), u.Billing.StripeSubscriptionCancelAt.Unix()) // Updated
// Verify that reservations were deleted. The test expects "ztopic" to be the one dropped
// (NOTE(review): selection rule lives in the server's downgrade logic — confirm there)
r, err := s.userManager.Reservations("phil")
require.Nil(t, err)
require.Equal(t, 1, len(r)) // "ztopic" reservation was deleted
require.Equal(t, "atopic", r[0].Topic)
// Verify that messages and attachments were deleted.
// NOTE(review): the sleep presumably lets asynchronous work triggered by the webhook finish
// before execManager runs its cleanup pass — confirm before shortening or removing it.
time.Sleep(time.Second)
s.execManager()
ms, err := s.messageCache.Messages("atopic", sinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 2, len(ms))
require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, a2.ID))
ms, err = s.messageCache.Messages("ztopic", sinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 0, len(ms))
require.NoFileExists(t, filepath.Join(s.config.AttachmentCacheDir, z2.ID))
}
// testStripeAPI is a testify-backed fake of the stripeAPI interface. Each method
// records the call through mock.Mock.Called and hands back whatever the test
// configured with On(...).Return(...). The value at index 0 is type-asserted
// without a check, so expectations must return a value of the exact type
// (an untyped nil would panic).
type testStripeAPI struct {
	mock.Mock
}

// NewCheckoutSession records params and returns the configured session and error.
func (s *testStripeAPI) NewCheckoutSession(params *stripe.CheckoutSessionParams) (*stripe.CheckoutSession, error) {
	call := s.Called(params)
	session := call.Get(0).(*stripe.CheckoutSession)
	return session, call.Error(1)
}

// NewPortalSession records params and returns the configured portal session and error.
func (s *testStripeAPI) NewPortalSession(params *stripe.BillingPortalSessionParams) (*stripe.BillingPortalSession, error) {
	call := s.Called(params)
	session := call.Get(0).(*stripe.BillingPortalSession)
	return session, call.Error(1)
}

// ListPrices records params and returns the configured price list and error.
func (s *testStripeAPI) ListPrices(params *stripe.PriceListParams) ([]*stripe.Price, error) {
	call := s.Called(params)
	prices := call.Get(0).([]*stripe.Price)
	return prices, call.Error(1)
}

// GetCustomer records id and returns the configured customer and error.
func (s *testStripeAPI) GetCustomer(id string) (*stripe.Customer, error) {
	call := s.Called(id)
	customer := call.Get(0).(*stripe.Customer)
	return customer, call.Error(1)
}

// GetSession records id and returns the configured checkout session and error.
func (s *testStripeAPI) GetSession(id string) (*stripe.CheckoutSession, error) {
	call := s.Called(id)
	session := call.Get(0).(*stripe.CheckoutSession)
	return session, call.Error(1)
}

// GetSubscription records id and returns the configured subscription and error.
func (s *testStripeAPI) GetSubscription(id string) (*stripe.Subscription, error) {
	call := s.Called(id)
	subscription := call.Get(0).(*stripe.Subscription)
	return subscription, call.Error(1)
}

// UpdateSubscription records only id. params is not forwarded to Called, so
// expectations are matched on the subscription ID alone.
func (s *testStripeAPI) UpdateSubscription(id string, params *stripe.SubscriptionParams) (*stripe.Subscription, error) {
	call := s.Called(id)
	subscription := call.Get(0).(*stripe.Subscription)
	return subscription, call.Error(1)
}

// CancelSubscription records id and returns the configured subscription and error.
func (s *testStripeAPI) CancelSubscription(id string) (*stripe.Subscription, error) {
	call := s.Called(id)
	subscription := call.Get(0).(*stripe.Subscription)
	return subscription, call.Error(1)
}

// ConstructWebhookEvent records the payload, signature header and secret, and
// returns the configured event (by value) and error.
func (s *testStripeAPI) ConstructWebhookEvent(payload []byte, header string, secret string) (stripe.Event, error) {
	call := s.Called(payload, header, secret)
	event := call.Get(0).(stripe.Event)
	return event, call.Error(1)
}

// Compile-time check that the fake satisfies the interface the server uses.
var _ stripeAPI = (*testStripeAPI)(nil)
// jsonToStripeEvent decodes the JSON document v into a stripe.Event with
// encoding/json. It fails the test immediately if decoding returns an error.
func jsonToStripeEvent(t *testing.T, v string) stripe.Event {
	t.Helper() // attribute a failure to the caller's line, not to this helper
	var e stripe.Event
	if err := json.Unmarshal([]byte(v), &e); err != nil {
		t.Fatalf("cannot unmarshal stripe event: %v", err)
	}
	return e
}
// subscriptionUpdatedEventJSON is a minimal Stripe "customer.subscription.updated"
// event: subscription sub_1234 of customer acct_5555 is now "active", billed with
// price price_1234, with current_period_end 1674268231 and cancel_at 1674299999
// (Unix seconds). Only the fields the tests need are included.
const subscriptionUpdatedEventJSON = `
{
"type": "customer.subscription.updated",
"data": {
"object": {
"id": "sub_1234",
"customer": "acct_5555",
"status": "active",
"current_period_end": 1674268231,
"cancel_at": 1674299999,
"items": {
"data": [
{
"price": {
"id": "price_1234"
}
}
]
}
}
}
}`