Add whitelist for PDS hosts and update repo PDS pointer on appropriate occasions

main
Max Ignatenko 2024-02-22 12:05:44 +00:00
parent 600dac7694
commit 8f32c494f7
8 changed files with 108 additions and 43 deletions

View File

@ -28,6 +28,7 @@ import (
"github.com/uabluerail/indexer/models" "github.com/uabluerail/indexer/models"
"github.com/uabluerail/indexer/pds" "github.com/uabluerail/indexer/pds"
"github.com/uabluerail/indexer/repo" "github.com/uabluerail/indexer/repo"
"github.com/uabluerail/indexer/util/resolver"
) )
type BadRecord struct { type BadRecord struct {
@ -245,11 +246,33 @@ func (c *Consumer) processMessage(ctx context.Context, typ string, r io.Reader,
if err != nil { if err != nil {
return fmt.Errorf("repo.EnsureExists(%q): %w", payload.Repo, err) return fmt.Errorf("repo.EnsureExists(%q): %w", payload.Repo, err)
} }
if repoInfo.PDS != models.ID(c.remote.ID) { if repoInfo.PDS != c.remote.ID {
u, err := resolver.GetPDSEndpoint(ctx, payload.Repo)
if err == nil {
cur, err := pds.EnsureExists(ctx, c.db, u.String())
if err == nil {
if repoInfo.PDS != cur.ID {
// Repo was migrated, let's update our record.
err := c.db.Model(repoInfo).Where(&repo.Repo{ID: repoInfo.ID}).Updates(&repo.Repo{PDS: cur.ID}).Error
if err != nil {
log.Error().Err(err).Msgf("Repo %q was migrated to %q, but updating the repo has failed: %s", payload.Repo, cur.Host, err)
}
}
repoInfo.PDS = cur.ID
} else {
log.Error().Err(err).Msgf("Failed to get PDS record for %q: %s", u, err)
}
} else {
log.Error().Err(err).Msgf("Failed to get PDS endpoint for repo %q: %s", payload.Repo, err)
}
if repoInfo.PDS != c.remote.ID {
// We checked a recent version of DID doc and this is still not a correct PDS.
log.Error().Str("did", payload.Repo).Str("rev", payload.Rev). log.Error().Str("did", payload.Repo).Str("rev", payload.Rev).
Msgf("Commit from an incorrect PDS, skipping") Msgf("Commit from an incorrect PDS, skipping")
return nil return nil
} }
}
if created { if created {
reposDiscovered.WithLabelValues(c.remote.Host).Inc() reposDiscovered.WithLabelValues(c.remote.Host).Inc()
} }
@ -436,7 +459,11 @@ func (c *Consumer) processMessage(ctx context.Context, typ string, r io.Reader,
log.Error().Msgf("Unknown #info message %q: %+v", payload.Name, payload) log.Error().Msgf("Unknown #info message %q: %+v", payload.Name, payload)
} }
default: default:
log.Warn().Msgf("Unknown message type received: %s", typ) b, err := io.ReadAll(r)
if err != nil {
log.Error().Err(err).Msgf("Failed to read message payload: %s", err)
}
log.Warn().Msgf("Unknown message type received: %s payload=%q", typ, string(b))
} }
return nil return nil
} }

View File

@ -60,6 +60,9 @@ func runMain(ctx context.Context) error {
} }
// TODO: check for changes and start/stop consumers as needed // TODO: check for changes and start/stop consumers as needed
for _, remote := range remotes { for _, remote := range remotes {
if remote.Disabled {
continue
}
c, err := NewConsumer(ctx, &remote, db) c, err := NewConsumer(ctx, &remote, db)
if err != nil { if err != nil {
return fmt.Errorf("failed to create a consumer for %q: %w", remote.Host, err) return fmt.Errorf("failed to create a consumer for %q: %w", remote.Host, err)

View File

@ -57,13 +57,22 @@ func (l *Lister) run(ctx context.Context) {
remote := pds.PDS{} remote := pds.PDS{}
if err := db.Model(&remote). if err := db.Model(&remote).
Where("last_list is null or last_list < ?", time.Now().Add(-l.listRefreshInterval)). Where("disabled=false and (last_list is null or last_list < ?)", time.Now().Add(-l.listRefreshInterval)).
Take(&remote).Error; err != nil { Take(&remote).Error; err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) { if !errors.Is(err, gorm.ErrRecordNotFound) {
log.Error().Err(err).Msgf("Failed to query DB for a PDS to list repos from: %s", err) log.Error().Err(err).Msgf("Failed to query DB for a PDS to list repos from: %s", err)
} }
break break
} }
if !pds.IsWhitelisted(remote.Host) {
log.Info().Msgf("PDS %q is not whitelisted, disabling it", remote.Host)
if err := db.Model(&remote).Where(&pds.PDS{ID: remote.ID}).Updates(&pds.PDS{Disabled: true}).Error; err != nil {
log.Error().Err(err).Msgf("Failed to disable PDS %q: %s", remote.Host, err)
}
break
}
client := xrpcauth.NewAnonymousClient(ctx) client := xrpcauth.NewAnonymousClient(ctx)
client.Host = remote.Host client.Host = remote.Host

View File

@ -3,6 +3,7 @@ package main
import ( import (
"context" "context"
"fmt" "fmt"
"slices"
"time" "time"
"github.com/rs/zerolog" "github.com/rs/zerolog"
@ -101,7 +102,11 @@ func (s *Scheduler) fillQueue(ctx context.Context) error {
if err := s.db.Find(&remotes).Error; err != nil { if err := s.db.Find(&remotes).Error; err != nil {
return fmt.Errorf("failed to get the list of PDSs: %w", err) return fmt.Errorf("failed to get the list of PDSs: %w", err)
} }
perPDSLimit := 0
remotes = slices.DeleteFunc(remotes, func(pds pds.PDS) bool {
return pds.Disabled
})
perPDSLimit := maxQueueLen
if len(remotes) > 0 { if len(remotes) > 0 {
perPDSLimit = maxQueueLen * 2 / len(remotes) perPDSLimit = maxQueueLen * 2 / len(remotes)
} }

View File

@ -4,7 +4,6 @@ import (
"bytes" "bytes"
"context" "context"
"fmt" "fmt"
"net/url"
"regexp" "regexp"
"strings" "strings"
"time" "time"
@ -143,27 +142,20 @@ func (p *WorkerPool) doWork(ctx context.Context, work WorkItem) error {
log := zerolog.Ctx(ctx) log := zerolog.Ctx(ctx)
defer close(work.signal) defer close(work.signal)
doc, err := resolver.GetDocument(ctx, work.Repo.DID) u, err := resolver.GetPDSEndpoint(ctx, work.Repo.DID)
if err != nil { if err != nil {
return fmt.Errorf("resolving did %q: %w", work.Repo.DID, err) return err
} }
pdsHost := "" remote, err := pds.EnsureExists(ctx, p.db, u.String())
for _, srv := range doc.Service {
if srv.Type != "AtprotoPersonalDataServer" {
continue
}
pdsHost = srv.ServiceEndpoint
}
if pdsHost == "" {
return fmt.Errorf("did not find any PDS in DID Document")
}
u, err := url.Parse(pdsHost)
if err != nil { if err != nil {
return fmt.Errorf("PDS endpoint (%q) is an invalid URL: %w", pdsHost, err) return fmt.Errorf("failed to get PDS records for %q: %w", u, err)
} }
if u.Host == "" { if work.Repo.PDS != remote.ID {
return fmt.Errorf("PDS endpoint (%q) doesn't have a host part", pdsHost) if err := p.db.Model(&work.Repo).Where(&repo.Repo{ID: work.Repo.ID}).Updates(&repo.Repo{PDS: remote.ID}).Error; err != nil {
return fmt.Errorf("failed to update repo's PDS to %q: %w", u, err)
}
work.Repo.PDS = remote.ID
} }
client := xrpcauth.NewAnonymousClient(ctx) client := xrpcauth.NewAnonymousClient(ctx)

View File

@ -3,6 +3,7 @@ package pds
import ( import (
"context" "context"
"fmt" "fmt"
"path/filepath"
"time" "time"
"gorm.io/gorm" "gorm.io/gorm"
@ -12,6 +13,11 @@ import (
const Unknown models.ID = 0 const Unknown models.ID = 0
// whitelist holds the glob patterns (filepath.Match syntax) that a PDS URL
// must match for IsWhitelisted to accept it. Patterns include the scheme.
var whitelist = []string{
	"https://bsky.social",
	"https://*.bsky.network",
}
type PDS struct { type PDS struct {
ID models.ID `gorm:"primarykey"` ID models.ID `gorm:"primarykey"`
CreatedAt time.Time CreatedAt time.Time
@ -21,6 +27,7 @@ type PDS struct {
FirstCursorSinceReset int64 FirstCursorSinceReset int64
LastList time.Time LastList time.Time
CrawlLimit int CrawlLimit int
Disabled bool
} }
func AutoMigrate(db *gorm.DB) error { func AutoMigrate(db *gorm.DB) error {
@ -28,9 +35,21 @@ func AutoMigrate(db *gorm.DB) error {
} }
// EnsureExists returns the PDS record whose Host equals host, creating the
// record if none exists yet. A host that does not pass IsWhitelisted is
// rejected with an error and a nil *PDS before the database is touched.
// Database failures are returned wrapped, also with a nil *PDS.
func EnsureExists(ctx context.Context, db *gorm.DB, host string) (*PDS, error) { func EnsureExists(ctx context.Context, db *gorm.DB, host string) (*PDS, error) {
// Refuse to create or look up records for hosts outside the whitelist.
if !IsWhitelisted(host) {
return nil, fmt.Errorf("host %q is not whitelisted", host)
}
// Look the record up by Host; FirstOrCreate inserts it when it is missing.
remote := PDS{Host: host} remote := PDS{Host: host}
if err := db.Model(&remote).Where(&PDS{Host: host}).FirstOrCreate(&remote).Error; err != nil { if err := db.Model(&remote).Where(&PDS{Host: host}).FirstOrCreate(&remote).Error; err != nil {
return nil, fmt.Errorf("failed to get PDS record from DB for %q: %w", remote.Host, err) return nil, fmt.Errorf("failed to get PDS record from DB for %q: %w", remote.Host, err)
} }
return &remote, nil return &remote, nil
} }
// IsWhitelisted reports whether host matches at least one pattern in
// whitelist. host is expected to be a bare PDS base URL including the scheme
// (for example "https://bsky.social"); patterns use filepath.Match syntax.
//
// The value typically originates from a DID document, i.e. it is untrusted.
// Because the glob is applied to the whole string, a '*' in a pattern would
// otherwise be free to span a query or fragment: without the check below,
// "https://evil.example?x=.bsky.network" matches "https://*.bsky.network"
// even though requests would be sent to evil.example. Any string containing
// '?' or '#' is therefore rejected outright. (A '/' is not matched by '*' on
// non-Windows platforms, so a path component cannot be used the same way
// there.)
func IsWhitelisted(host string) bool {
	for i := 0; i < len(host); i++ {
		if host[i] == '?' || host[i] == '#' {
			return false
		}
	}
	for _, p := range whitelist {
		// A non-nil error only means the pattern itself is malformed; such a
		// pattern can never match, so it is skipped.
		if ok, err := filepath.Match(p, host); err == nil && ok {
			return true
		}
	}
	return false
}

View File

@ -5,7 +5,6 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"net/url"
"time" "time"
"gorm.io/gorm" "gorm.io/gorm"
@ -66,28 +65,11 @@ func EnsureExists(ctx context.Context, db *gorm.DB, did string) (*Repo, bool, er
// if we do - compare PDS IDs // if we do - compare PDS IDs
// if they don't match - also reset FirstRevSinceReset // if they don't match - also reset FirstRevSinceReset
doc, err := resolver.GetDocument(ctx, did) u, err := resolver.GetPDSEndpoint(ctx, did)
if err != nil { if err != nil {
return nil, false, fmt.Errorf("fetching DID Document: %w", err) return nil, false, fmt.Errorf("fetching DID Document: %w", err)
} }
pdsHost := ""
for _, srv := range doc.Service {
if srv.Type != "AtprotoPersonalDataServer" {
continue
}
pdsHost = srv.ServiceEndpoint
}
if pdsHost == "" {
return nil, false, fmt.Errorf("did not find any PDS in DID Document")
}
u, err := url.Parse(pdsHost)
if err != nil {
return nil, false, fmt.Errorf("PDS endpoint (%q) is an invalid URL: %w", pdsHost, err)
}
if u.Host == "" {
return nil, false, fmt.Errorf("PDS endpoint (%q) doesn't have a host part", pdsHost)
}
remote, err := pds.EnsureExists(ctx, db, u.String()) remote, err := pds.EnsureExists(ctx, db, u.String())
if err != nil { if err != nil {
return nil, false, fmt.Errorf("failed to get PDS record from DB for %q: %w", remote.Host, err) return nil, false, fmt.Errorf("failed to get PDS record from DB for %q: %w", remote.Host, err)

View File

@ -3,6 +3,8 @@ package resolver
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"net/url"
"os" "os"
"github.com/bluesky-social/indigo/api" "github.com/bluesky-social/indigo/api"
@ -56,3 +58,29 @@ func (r *fallbackResolver) FlushCacheFor(did string) {
res.FlushCacheFor(did) res.FlushCacheFor(did)
} }
} }
// GetPDSEndpoint resolves the DID document for did and returns the parsed
// service endpoint of its "AtprotoPersonalDataServer" service entry. When the
// document lists several such entries, the last one wins.
//
// An error is returned if the document cannot be resolved, if it has no PDS
// service entry with a non-empty endpoint, or if the endpoint is not a valid
// URL with a host part.
func GetPDSEndpoint(ctx context.Context, did string) (*url.URL, error) {
	document, err := GetDocument(ctx, did)
	if err != nil {
		return nil, fmt.Errorf("resolving did %q: %w", did, err)
	}

	var endpoint string
	for _, service := range document.Service {
		if service.Type == "AtprotoPersonalDataServer" {
			// Keep overwriting so that the last matching entry is used.
			endpoint = service.ServiceEndpoint
		}
	}
	if endpoint == "" {
		return nil, fmt.Errorf("did not find any PDS in DID Document")
	}

	parsed, err := url.Parse(endpoint)
	switch {
	case err != nil:
		return nil, fmt.Errorf("PDS endpoint (%q) is an invalid URL: %w", endpoint, err)
	case parsed.Host == "":
		return nil, fmt.Errorf("PDS endpoint (%q) doesn't have a host part", endpoint)
	}
	return parsed, nil
}