Add whitelist for PDS hosts and update repo PDS pointer on appropriate occasions

This commit is contained in:
Max Ignatenko 2024-02-22 12:05:44 +00:00
parent 600dac7694
commit 8f32c494f7
8 changed files with 108 additions and 43 deletions

View file

@ -3,6 +3,7 @@ package main
import (
"context"
"fmt"
"slices"
"time"
"github.com/rs/zerolog"
@ -101,7 +102,11 @@ func (s *Scheduler) fillQueue(ctx context.Context) error {
if err := s.db.Find(&remotes).Error; err != nil {
return fmt.Errorf("failed to get the list of PDSs: %w", err)
}
perPDSLimit := 0
remotes = slices.DeleteFunc(remotes, func(pds pds.PDS) bool {
return pds.Disabled
})
perPDSLimit := maxQueueLen
if len(remotes) > 0 {
perPDSLimit = maxQueueLen * 2 / len(remotes)
}

View file

@ -4,7 +4,6 @@ import (
"bytes"
"context"
"fmt"
"net/url"
"regexp"
"strings"
"time"
@ -143,27 +142,20 @@ func (p *WorkerPool) doWork(ctx context.Context, work WorkItem) error {
log := zerolog.Ctx(ctx)
defer close(work.signal)
doc, err := resolver.GetDocument(ctx, work.Repo.DID)
u, err := resolver.GetPDSEndpoint(ctx, work.Repo.DID)
if err != nil {
return fmt.Errorf("resolving did %q: %w", work.Repo.DID, err)
return err
}
pdsHost := ""
for _, srv := range doc.Service {
if srv.Type != "AtprotoPersonalDataServer" {
continue
}
pdsHost = srv.ServiceEndpoint
}
if pdsHost == "" {
return fmt.Errorf("did not find any PDS in DID Document")
}
u, err := url.Parse(pdsHost)
remote, err := pds.EnsureExists(ctx, p.db, u.String())
if err != nil {
return fmt.Errorf("PDS endpoint (%q) is an invalid URL: %w", pdsHost, err)
return fmt.Errorf("failed to get PDS records for %q: %w", u, err)
}
if u.Host == "" {
return fmt.Errorf("PDS endpoint (%q) doesn't have a host part", pdsHost)
if work.Repo.PDS != remote.ID {
if err := p.db.Model(&work.Repo).Where(&repo.Repo{ID: work.Repo.ID}).Updates(&repo.Repo{PDS: remote.ID}).Error; err != nil {
return fmt.Errorf("failed to update repo's PDS to %q: %w", u, err)
}
work.Repo.PDS = remote.ID
}
client := xrpcauth.NewAnonymousClient(ctx)