plc-mirror/cmd/record-indexer/workerpool.go

239 lines
5.9 KiB
Go

package main
import (
"bytes"
"context"
"fmt"
"net/url"
"strings"
"time"
"github.com/imax9000/errors"
"github.com/rs/zerolog"
"gorm.io/gorm"
"gorm.io/gorm/clause"
comatproto "github.com/bluesky-social/indigo/api/atproto"
"github.com/bluesky-social/indigo/xrpc"
"github.com/uabluerail/bsky-tools/xrpcauth"
"github.com/uabluerail/indexer/models"
"github.com/uabluerail/indexer/repo"
"github.com/uabluerail/indexer/util/resolver"
)
type WorkItem struct {
Repo *repo.Repo
signal chan struct{}
}
type WorkerPool struct {
db *gorm.DB
input <-chan WorkItem
limiter *Limiter
workerSignals []chan struct{}
resize chan int
}
func NewWorkerPool(input <-chan WorkItem, db *gorm.DB, size int, limiter *Limiter) *WorkerPool {
r := &WorkerPool{
db: db,
input: input,
limiter: limiter,
resize: make(chan int),
}
r.workerSignals = make([]chan struct{}, size)
for i := range r.workerSignals {
r.workerSignals[i] = make(chan struct{})
}
return r
}
func (p *WorkerPool) Start(ctx context.Context) error {
go p.run(ctx)
return nil
}
func (p *WorkerPool) Resize(ctx context.Context, size int) error {
select {
case <-ctx.Done():
return ctx.Err()
case p.resize <- size:
return nil
}
}
func (p *WorkerPool) run(ctx context.Context) {
for _, ch := range p.workerSignals {
go p.worker(ctx, ch)
}
workerPoolSize.Set(float64(len(p.workerSignals)))
for {
select {
case <-ctx.Done():
for _, ch := range p.workerSignals {
close(ch)
}
// also wait for all workers to stop?
case newSize := <-p.resize:
switch {
case newSize > len(p.workerSignals):
ch := make([]chan struct{}, newSize-len(p.workerSignals))
for i := range ch {
ch[i] = make(chan struct{})
go p.worker(ctx, ch[i])
}
p.workerSignals = append(p.workerSignals, ch...)
workerPoolSize.Set(float64(len(p.workerSignals)))
case newSize < len(p.workerSignals) && newSize > 0:
for _, ch := range p.workerSignals[newSize:] {
close(ch)
}
p.workerSignals = p.workerSignals[:newSize]
workerPoolSize.Set(float64(len(p.workerSignals)))
}
}
}
}
func (p *WorkerPool) worker(ctx context.Context, signal chan struct{}) {
log := zerolog.Ctx(ctx)
for {
select {
case <-ctx.Done():
return
case <-signal:
return
case work := <-p.input:
updates := &repo.Repo{}
if err := p.doWork(ctx, work); err != nil {
log.Error().Err(err).Msgf("Work task %q failed: %s", work.Repo.DID, err)
updates.LastError = err.Error()
reposIndexed.WithLabelValues("false").Inc()
} else {
reposIndexed.WithLabelValues("true").Inc()
}
updates.LastIndexAttempt = time.Now()
err := p.db.Model(&repo.Repo{}).
Where(&repo.Repo{Model: gorm.Model{ID: work.Repo.ID}}).
Updates(updates).Error
if err != nil {
log.Error().Err(err).Msgf("Failed to update repo info for %q: %s", work.Repo.DID, err)
}
}
}
}
func (p *WorkerPool) doWork(ctx context.Context, work WorkItem) error {
log := zerolog.Ctx(ctx)
defer close(work.signal)
doc, err := resolver.GetDocument(ctx, work.Repo.DID)
if err != nil {
return fmt.Errorf("resolving did %q: %w", work.Repo.DID, err)
}
pdsHost := ""
for _, srv := range doc.Service {
if srv.Type != "AtprotoPersonalDataServer" {
continue
}
pdsHost = srv.ServiceEndpoint
}
if pdsHost == "" {
return fmt.Errorf("did not find any PDS in DID Document")
}
u, err := url.Parse(pdsHost)
if err != nil {
return fmt.Errorf("PDS endpoint (%q) is an invalid URL: %w", pdsHost, err)
}
if u.Host == "" {
return fmt.Errorf("PDS endpoint (%q) doesn't have a host part", pdsHost)
}
client := xrpcauth.NewAnonymousClient(ctx)
client.Host = u.String()
retry:
if p.limiter != nil {
if err := p.limiter.Wait(ctx, u.String()); err != nil {
return fmt.Errorf("failed to wait on rate limiter: %w", err)
}
}
b, err := comatproto.SyncGetRepo(ctx, client, work.Repo.DID, "")
if err != nil {
if err, ok := errors.As[*xrpc.Error](err); ok {
if err.IsThrottled() && err.Ratelimit != nil {
log.Debug().Str("pds", u.String()).Msgf("Hit a rate limit (%s), sleeping until %s", err.Ratelimit.Policy, err.Ratelimit.Reset)
time.Sleep(time.Until(err.Ratelimit.Reset))
goto retry
}
}
reposFetched.WithLabelValues(u.String(), "false").Inc()
return fmt.Errorf("failed to fetch repo: %w", err)
}
reposFetched.WithLabelValues(u.String(), "true").Inc()
newRev, err := repo.GetRev(ctx, bytes.NewReader(b))
if err != nil {
return fmt.Errorf("failed to read 'rev' from the fetched repo: %w", err)
}
newRecs, err := repo.ExtractRecords(ctx, bytes.NewReader(b))
if err != nil {
return fmt.Errorf("failed to extract records: %w", err)
}
recs := []repo.Record{}
for k, v := range newRecs {
parts := strings.SplitN(k, "/", 2)
if len(parts) != 2 {
log.Warn().Msgf("Unexpected key format: %q", k)
continue
}
recs = append(recs, repo.Record{
Repo: models.ID(work.Repo.ID),
Collection: parts[0],
Rkey: parts[1],
Content: v,
})
}
recordsFetched.Add(float64(len(recs)))
if len(recs) > 0 {
for _, batch := range splitInBatshes(recs, 500) {
result := p.db.Model(&repo.Record{}).
Clauses(clause.OnConflict{DoUpdates: clause.AssignmentColumns([]string{"content"}),
Columns: []clause.Column{{Name: "repo"}, {Name: "collection"}, {Name: "rkey"}}}).
Create(batch)
if err := result.Error; err != nil {
return fmt.Errorf("inserting records into the database: %w", err)
}
recordsInserted.Add(float64(result.RowsAffected))
}
}
err = p.db.Model(&repo.Repo{}).Where(&repo.Repo{Model: gorm.Model{ID: work.Repo.ID}}).
Updates(&repo.Repo{LastIndexedRev: newRev}).Error
if err != nil {
return fmt.Errorf("updating repo rev: %w", err)
}
return nil
}
func splitInBatshes[T any](s []T, batchSize int) [][]T {
var r [][]T
for i := 0; i < len(s); i += batchSize {
if i+batchSize < len(s) {
r = append(r, s[i:i+batchSize])
} else {
r = append(r, s[i:])
}
}
return r
}