131 lines
4.3 KiB
Go
131 lines
4.3 KiB
Go
|
package server
|
||
|
|
||
|
import (
|
||
|
"github.com/stretchr/testify/mock"
|
||
|
"github.com/stretchr/testify/require"
|
||
|
"github.com/stripe/stripe-go/v74"
|
||
|
"heckel.io/ntfy/user"
|
||
|
"heckel.io/ntfy/util"
|
||
|
"io"
|
||
|
"testing"
|
||
|
)
|
||
|
|
||
|
func TestPayments_SubscriptionCreate_NotAStripeCustomer_Success(t *testing.T) {
|
||
|
stripeMock := &testStripeAPI{}
|
||
|
defer stripeMock.AssertExpectations(t)
|
||
|
|
||
|
c := newTestConfigWithAuthFile(t)
|
||
|
c.StripeSecretKey = "secret key"
|
||
|
c.StripeWebhookKey = "webhook key"
|
||
|
s := newTestServer(t, c)
|
||
|
s.stripe = stripeMock
|
||
|
|
||
|
// Define how the mock should react
|
||
|
stripeMock.
|
||
|
On("NewCheckoutSession", mock.Anything).
|
||
|
Return(&stripe.CheckoutSession{URL: "https://billing.stripe.com/abc/def"}, nil)
|
||
|
|
||
|
// Create tier and user
|
||
|
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
||
|
Code: "pro",
|
||
|
StripePriceID: "price_123",
|
||
|
}))
|
||
|
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
|
||
|
|
||
|
// Create subscription
|
||
|
response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro"}`, map[string]string{
|
||
|
"Authorization": util.BasicAuth("phil", "phil"),
|
||
|
})
|
||
|
require.Equal(t, 200, response.Code)
|
||
|
redirectResponse, err := util.UnmarshalJSON[apiAccountBillingSubscriptionCreateResponse](io.NopCloser(response.Body))
|
||
|
require.Nil(t, err)
|
||
|
require.Equal(t, "https://billing.stripe.com/abc/def", redirectResponse.RedirectURL)
|
||
|
}
|
||
|
|
||
|
func TestPayments_SubscriptionCreate_StripeCustomer_Success(t *testing.T) {
|
||
|
stripeMock := &testStripeAPI{}
|
||
|
defer stripeMock.AssertExpectations(t)
|
||
|
|
||
|
c := newTestConfigWithAuthFile(t)
|
||
|
c.StripeSecretKey = "secret key"
|
||
|
c.StripeWebhookKey = "webhook key"
|
||
|
s := newTestServer(t, c)
|
||
|
s.stripe = stripeMock
|
||
|
|
||
|
// Define how the mock should react
|
||
|
stripeMock.
|
||
|
On("GetCustomer", "acct_123").
|
||
|
Return(&stripe.Customer{Subscriptions: &stripe.SubscriptionList{}}, nil)
|
||
|
stripeMock.
|
||
|
On("NewCheckoutSession", mock.Anything).
|
||
|
Return(&stripe.CheckoutSession{URL: "https://billing.stripe.com/abc/def"}, nil)
|
||
|
|
||
|
// Create tier and user
|
||
|
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
||
|
Code: "pro",
|
||
|
StripePriceID: "price_123",
|
||
|
}))
|
||
|
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
|
||
|
|
||
|
u, err := s.userManager.User("phil")
|
||
|
require.Nil(t, err)
|
||
|
|
||
|
u.Billing.StripeCustomerID = "acct_123"
|
||
|
require.Nil(t, s.userManager.ChangeBilling(u))
|
||
|
|
||
|
// Create subscription
|
||
|
response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro"}`, map[string]string{
|
||
|
"Authorization": util.BasicAuth("phil", "phil"),
|
||
|
})
|
||
|
require.Equal(t, 200, response.Code)
|
||
|
redirectResponse, err := util.UnmarshalJSON[apiAccountBillingSubscriptionCreateResponse](io.NopCloser(response.Body))
|
||
|
require.Nil(t, err)
|
||
|
require.Equal(t, "https://billing.stripe.com/abc/def", redirectResponse.RedirectURL)
|
||
|
}
|
||
|
|
||
|
// testStripeAPI is a testify-backed fake of the stripeAPI interface used by the
// server's payment code. Tests stub methods with On(name, args...).Return(...)
// and check that every stub was used via AssertExpectations.
type testStripeAPI struct {
	// Embedded so testStripeAPI gains Called/On/AssertExpectations.
	mock.Mock
}
|
||
|
|
||
|
func (s *testStripeAPI) NewCheckoutSession(params *stripe.CheckoutSessionParams) (*stripe.CheckoutSession, error) {
|
||
|
args := s.Called(params)
|
||
|
return args.Get(0).(*stripe.CheckoutSession), args.Error(1)
|
||
|
}
|
||
|
|
||
|
func (s *testStripeAPI) NewPortalSession(params *stripe.BillingPortalSessionParams) (*stripe.BillingPortalSession, error) {
|
||
|
args := s.Called(params)
|
||
|
return args.Get(0).(*stripe.BillingPortalSession), args.Error(1)
|
||
|
}
|
||
|
|
||
|
func (s *testStripeAPI) ListPrices(params *stripe.PriceListParams) ([]*stripe.Price, error) {
|
||
|
args := s.Called(params)
|
||
|
return args.Get(0).([]*stripe.Price), args.Error(1)
|
||
|
}
|
||
|
|
||
|
func (s *testStripeAPI) GetCustomer(id string) (*stripe.Customer, error) {
|
||
|
args := s.Called(id)
|
||
|
return args.Get(0).(*stripe.Customer), args.Error(1)
|
||
|
}
|
||
|
|
||
|
func (s *testStripeAPI) GetSession(id string) (*stripe.CheckoutSession, error) {
|
||
|
args := s.Called(id)
|
||
|
return args.Get(0).(*stripe.CheckoutSession), args.Error(1)
|
||
|
}
|
||
|
|
||
|
func (s *testStripeAPI) GetSubscription(id string) (*stripe.Subscription, error) {
|
||
|
args := s.Called(id)
|
||
|
return args.Get(0).(*stripe.Subscription), args.Error(1)
|
||
|
}
|
||
|
|
||
|
func (s *testStripeAPI) UpdateSubscription(id string, params *stripe.SubscriptionParams) (*stripe.Subscription, error) {
|
||
|
args := s.Called(id)
|
||
|
return args.Get(0).(*stripe.Subscription), args.Error(1)
|
||
|
}
|
||
|
|
||
|
func (s *testStripeAPI) ConstructWebhookEvent(payload []byte, header string, secret string) (stripe.Event, error) {
|
||
|
args := s.Called(payload, header, secret)
|
||
|
return args.Get(0).(stripe.Event), args.Error(1)
|
||
|
}
|
||
|
|
||
|
// Compile-time check that *testStripeAPI implements stripeAPI, so the mock
// cannot silently drift from the interface it stands in for.
var _ stripeAPI = (*testStripeAPI)(nil)
|