commit 63a767d890 (parent 2b6abac607)
Author: Max Ignatenko
Date:   2024-02-15 16:10:39 +00:00

25 changed files with 3027 additions and 0 deletions

Dockerfile (new file)

@@ -0,0 +1,14 @@
FROM golang:1.21.1 as builder
WORKDIR /app
COPY go.mod go.sum ./
RUN go mod download
COPY . ./
RUN go build -trimpath ./cmd/record-indexer
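# Keep the runtime image small: pull CA certificates from Alpine so outbound
# TLS connections work, then copy only the certificates and the compiled
# binary into a slim Debian base.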
FROM alpine:latest as certs
RUN apk --update add ca-certificates
FROM debian:stable-slim
COPY --from=certs /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/ca-certificates.crt
COPY --from=builder /app/record-indexer .
ENTRYPOINT ["./record-indexer"]

Admin HTTP handlers (new file)

@@ -0,0 +1,77 @@
package main
import (
"context"
"fmt"
"net/http"
"strconv"
"golang.org/x/time/rate"
)
func AddAdminHandlers(limiter *Limiter, pool *WorkerPool) {
http.HandleFunc("/rate/set", handleRateSet(limiter))
http.HandleFunc("/rate/setAll", handleRateSetAll(limiter))
http.HandleFunc("/pool/resize", handlePoolResize(pool))
}
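// Example invocations, assuming the handlers are served on the metrics port
// configured via INDEXER_METRICS_PORT (the port below is illustrative, and the
// "name" value must match the key the limiter uses for that PDS):
//
//	curl 'http://localhost:11004/pool/resize?size=10'
//	curl 'http://localhost:11004/rate/set?name=pds.example.com&limit=20'
//	curl 'http://localhost:11004/rate/setAll?limit=20'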
func handlePoolResize(pool *WorkerPool) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
s := r.FormValue("size")
if s == "" {
http.Error(w, "need size", http.StatusBadRequest)
return
}
size, err := strconv.Atoi(s)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if size <= 0 {
http.Error(w, "size must be a positive integer", http.StatusBadRequest)
return
}
if err := pool.Resize(r.Context(), size); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
fmt.Fprintln(w, "OK")
}
}
func handleRateSet(limiter *Limiter) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
s := r.FormValue("limit")
if s == "" {
http.Error(w, "need limit", http.StatusBadRequest)
return
}
name := r.FormValue("name")
if name == "" {
http.Error(w, "need name", http.StatusBadRequest)
return
}
limit, err := strconv.Atoi(s)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
limiter.SetLimit(context.Background(), name, rate.Limit(limit))
fmt.Fprintln(w, "OK")
}
}
func handleRateSetAll(limiter *Limiter) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
s := r.FormValue("limit")
if s == "" {
http.Error(w, "need limit", http.StatusBadRequest)
return
}
limit, err := strconv.Atoi(s)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
limiter.SetAllLimits(context.Background(), rate.Limit(limit))
fmt.Fprintln(w, "OK")
}
}

cmd/record-indexer/main.go (new file, 179 lines)

@@ -0,0 +1,179 @@
package main
import (
"context"
"flag"
"fmt"
"io"
"log"
"net/http"
_ "net/http/pprof"
"os"
"os/signal"
"path/filepath"
"runtime"
"runtime/debug"
"strings"
"syscall"
"time"
_ "github.com/joho/godotenv/autoload"
"github.com/kelseyhightower/envconfig"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/rs/zerolog"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"github.com/uabluerail/indexer/util/gormzerolog"
)
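// Config is populated from environment variables via envconfig (prefix
// "indexer", e.g. INDEXER_METRICS_PORT; POSTGRES_URL is an explicit exception)
// and can then be overridden by the command-line flags defined in main().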
type Config struct {
LogFile string
LogFormat string `default:"text"`
LogLevel int64 `default:"1"`
MetricsPort string `split_words:"true"`
DBUrl string `envconfig:"POSTGRES_URL"`
Workers int `default:"2"`
}
var config Config
func runMain(ctx context.Context) error {
ctx = setupLogging(ctx)
log := zerolog.Ctx(ctx)
log.Debug().Msgf("Starting up...")
db, err := gorm.Open(postgres.Open(config.DBUrl), &gorm.Config{
Logger: gormzerolog.New(&logger.Config{
SlowThreshold: 3 * time.Second,
IgnoreRecordNotFoundError: true,
}, nil),
})
if err != nil {
return fmt.Errorf("connecting to the database: %w", err)
}
log.Debug().Msgf("DB connection established")
limiter, err := NewLimiter(db)
if err != nil {
return fmt.Errorf("failed to create limiter: %w", err)
}
ch := make(chan WorkItem)
pool := NewWorkerPool(ch, db, config.Workers, limiter)
if err := pool.Start(ctx); err != nil {
return fmt.Errorf("failed to start worker pool: %w", err)
}
scheduler := NewScheduler(ch, db)
if err := scheduler.Start(ctx); err != nil {
return fmt.Errorf("failed to start scheduler: %w", err)
}
log.Info().Msgf("Starting HTTP listener on %q...", config.MetricsPort)
AddAdminHandlers(limiter, pool)
http.Handle("/metrics", promhttp.Handler())
srv := &http.Server{Addr: fmt.Sprintf(":%s", config.MetricsPort)}
errCh := make(chan error)
go func() {
errCh <- srv.ListenAndServe()
}()
select {
case err := <-errCh:
// ListenAndServe failed on its own (e.g. the port is already in use).
return fmt.Errorf("HTTP server terminated unexpectedly: %w", err)
case <-ctx.Done():
if err := srv.Shutdown(context.Background()); err != nil {
return fmt.Errorf("HTTP server shutdown failed: %w", err)
}
// After a graceful Shutdown, ListenAndServe returns http.ErrServerClosed.
if err := <-errCh; err != nil && err != http.ErrServerClosed {
return err
}
return nil
}
}
func main() {
flag.StringVar(&config.LogFile, "log", "", "Path to the log file. If empty, will log to stderr")
flag.StringVar(&config.LogFormat, "log-format", "text", "Logging format. 'text' or 'json'")
flag.Int64Var(&config.LogLevel, "log-level", 1, "Log level. -1 - trace, 0 - debug, 1 - info, 5 - panic")
flag.IntVar(&config.Workers, "workers", 2, "Number of workers to start with")
if err := envconfig.Process("indexer", &config); err != nil {
log.Fatalf("envconfig.Process: %s", err)
}
flag.Parse()
ctx, _ := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
if err := runMain(ctx); err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
}
func setupLogging(ctx context.Context) context.Context {
logFile := os.Stderr
if config.LogFile != "" {
f, err := os.OpenFile(config.LogFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
log.Fatalf("Failed to open the specified log file %q: %s", config.LogFile, err)
}
logFile = f
}
var output io.Writer
switch config.LogFormat {
case "json":
output = logFile
case "text":
prefixList := []string{}
info, ok := debug.ReadBuildInfo()
if ok {
prefixList = append(prefixList, info.Path+"/")
}
basedir := ""
_, sourceFile, _, ok := runtime.Caller(0)
if ok {
basedir = filepath.Dir(sourceFile)
}
if basedir != "" && strings.HasPrefix(basedir, "/") {
prefixList = append(prefixList, basedir+"/")
head, _ := filepath.Split(basedir)
for head != "/" {
prefixList = append(prefixList, head)
head, _ = filepath.Split(strings.TrimSuffix(head, "/"))
}
}
output = zerolog.ConsoleWriter{
Out: logFile,
NoColor: true,
TimeFormat: time.RFC3339,
PartsOrder: []string{
zerolog.LevelFieldName,
zerolog.TimestampFieldName,
zerolog.CallerFieldName,
zerolog.MessageFieldName,
},
FormatFieldName: func(i interface{}) string { return fmt.Sprintf("%s:", i) },
FormatFieldValue: func(i interface{}) string { return fmt.Sprintf("%s", i) },
FormatCaller: func(i interface{}) string {
s := i.(string)
for _, p := range prefixList {
s = strings.TrimPrefix(s, p)
}
return s
},
}
default:
log.Fatalf("Invalid log format specified: %q", config.LogFormat)
}
logger := zerolog.New(output).Level(zerolog.Level(config.LogLevel)).With().Caller().Timestamp().Logger()
ctx = logger.WithContext(ctx)
zerolog.DefaultContextLogger = &logger
log.SetOutput(logger)
return ctx
}

Prometheus metrics (new file)

@@ -0,0 +1,41 @@
package main
import (
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
)
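// Prometheus metrics for the indexer. All of them are registered via promauto
// and exposed by the /metrics handler set up in main.go.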
var reposQueued = promauto.NewCounter(prometheus.CounterOpts{
Name: "indexer_repos_queued_count",
Help: "Number of repos added to the queue",
})
var queueLength = promauto.NewGaugeVec(prometheus.GaugeOpts{
Name: "indexer_queue_length",
Help: "Current length of indexing queue",
}, []string{"state"})
var reposFetched = promauto.NewCounterVec(prometheus.CounterOpts{
Name: "indexer_repos_fetched_count",
Help: "Number of repos fetched",
}, []string{"remote", "success"})
var reposIndexed = promauto.NewCounterVec(prometheus.CounterOpts{
Name: "indexer_repos_indexed_count",
Help: "Number of repos indexed",
}, []string{"success"})
var recordsFetched = promauto.NewCounter(prometheus.CounterOpts{
Name: "indexer_records_fetched_count",
Help: "Number of records fetched",
})
var recordsInserted = promauto.NewCounter(prometheus.CounterOpts{
Name: "indexer_records_inserted_count",
Help: "Number of records inserted into DB",
})
var workerPoolSize = promauto.NewGauge(prometheus.GaugeOpts{
Name: "indexer_workers_count",
Help: "Current number of workers running",
})

Per-PDS rate limiter (new file)

@@ -0,0 +1,82 @@
package main
import (
"context"
"fmt"
"sync"
"github.com/rs/zerolog"
"github.com/uabluerail/indexer/pds"
"golang.org/x/time/rate"
"gorm.io/gorm"
)
const defaultRateLimit = 10
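// Limiter maintains a token-bucket rate limiter per PDS (burst is set to twice
// the per-second limit) and persists any limit changes back to the database.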
type Limiter struct {
mu sync.RWMutex
db *gorm.DB
limiter map[string]*rate.Limiter
}
func NewLimiter(db *gorm.DB) (*Limiter, error) {
remotes := []pds.PDS{}
if err := db.Find(&remotes).Error; err != nil {
return nil, fmt.Errorf("failed to get the list of known PDSs: %w", err)
}
l := &Limiter{
db: db,
limiter: map[string]*rate.Limiter{},
}
for _, remote := range remotes {
limit := remote.CrawlLimit
if limit == 0 {
limit = defaultRateLimit
}
l.limiter[remote.Host] = rate.NewLimiter(rate.Limit(limit), limit*2)
}
return l, nil
}
func (l *Limiter) getLimiter(name string) *rate.Limiter {
l.mu.RLock()
limiter := l.limiter[name]
l.mu.RUnlock()
if limiter != nil {
return limiter
}
l.mu.Lock()
defer l.mu.Unlock()
// Re-check under the write lock: another goroutine may have created a limiter
// for this name between the RUnlock above and acquiring the write lock.
if limiter := l.limiter[name]; limiter != nil {
return limiter
}
limiter = rate.NewLimiter(defaultRateLimit, defaultRateLimit*2)
l.limiter[name] = limiter
return limiter
}
func (l *Limiter) Wait(ctx context.Context, name string) error {
return l.getLimiter(name).Wait(ctx)
}
func (l *Limiter) SetLimit(ctx context.Context, name string, limit rate.Limit) {
l.getLimiter(name).SetLimit(limit)
err := l.db.Model(&pds.PDS{}).Where(&pds.PDS{Host: name}).Updates(&pds.PDS{CrawlLimit: int(limit)}).Error
if err != nil {
zerolog.Ctx(ctx).Error().Err(err).Msgf("Failed to persist rate limit change for %q: %s", name, err)
}
}
func (l *Limiter) SetAllLimits(ctx context.Context, limit rate.Limit) {
l.mu.RLock()
for name, limiter := range l.limiter {
limiter.SetLimit(limit)
err := l.db.Model(&pds.PDS{}).Where(&pds.PDS{Host: name}).Updates(&pds.PDS{CrawlLimit: int(limit)}).Error
if err != nil {
zerolog.Ctx(ctx).Error().Err(err).Msgf("Failed to persist rate limit change for %q: %s", name, err)
}
}
l.mu.RUnlock()
}

Scheduler (new file)

@@ -0,0 +1,136 @@
package main
import (
"context"
"fmt"
"time"
"github.com/rs/zerolog"
"github.com/uabluerail/indexer/pds"
"github.com/uabluerail/indexer/repo"
"gorm.io/gorm"
)
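// Scheduler keeps an in-memory queue of repos that need (re)indexing, refills
// it from the database once a minute, and hands items to the worker pool one
// at a time over the output channel.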
type Scheduler struct {
db *gorm.DB
output chan<- WorkItem
queue map[string]*repo.Repo
inProgress map[string]*repo.Repo
}
func NewScheduler(output chan<- WorkItem, db *gorm.DB) *Scheduler {
return &Scheduler{
db: db,
output: output,
queue: map[string]*repo.Repo{},
inProgress: map[string]*repo.Repo{},
}
}
func (s *Scheduler) Start(ctx context.Context) error {
go s.run(ctx)
return nil
}
func (s *Scheduler) run(ctx context.Context) {
log := zerolog.Ctx(ctx)
t := time.NewTicker(time.Minute)
defer t.Stop()
if err := s.fillQueue(ctx); err != nil {
log.Error().Err(err).Msgf("Failed to get more tasks for the queue: %s", err)
}
done := make(chan string)
for {
if len(s.queue) > 0 {
next := WorkItem{signal: make(chan struct{})}
for _, r := range s.queue {
next.Repo = r
break
}
select {
case <-ctx.Done():
return
case <-t.C:
if err := s.fillQueue(ctx); err != nil {
log.Error().Err(err).Msgf("Failed to get more tasks for the queue: %s", err)
}
case s.output <- next:
delete(s.queue, next.Repo.DID)
s.inProgress[next.Repo.DID] = next.Repo
go func(did string, ch chan struct{}) {
select {
case <-ch:
case <-ctx.Done():
}
done <- did
}(next.Repo.DID, next.signal)
s.updateQueueLenMetrics()
case did := <-done:
delete(s.inProgress, did)
s.updateQueueLenMetrics()
}
} else {
select {
case <-ctx.Done():
return
case <-t.C:
if err := s.fillQueue(ctx); err != nil {
log.Error().Err(err).Msgf("Failed to get more tasks for the queue: %s", err)
}
case did := <-done:
delete(s.inProgress, did)
s.updateQueueLenMetrics()
}
}
}
}
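// fillQueue tops up the queue from the database, spreading capacity across all
// known PDS hosts and selecting only repos that were never indexed or whose
// last indexed rev predates first_rev_since_reset.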
func (s *Scheduler) fillQueue(ctx context.Context) error {
const maxQueueLen = 10000
if len(s.queue)+len(s.inProgress) >= maxQueueLen {
return nil
}
remotes := []pds.PDS{}
if err := s.db.Find(&remotes).Error; err != nil {
return fmt.Errorf("failed to get the list of PDSs: %w", err)
}
if len(remotes) == 0 {
// Avoid division by zero below when no PDSs are known yet.
return nil
}
perPDSLimit := maxQueueLen * 2 / len(remotes)
for _, remote := range remotes {
repos := []repo.Repo{}
err := s.db.Model(&repos).
Where(`pds = ? AND (
last_indexed_rev is null OR last_indexed_rev = ''
OR (first_rev_since_reset is not null AND first_rev_since_reset <> '' AND last_indexed_rev < first_rev_since_reset))`,
remote.ID).
Limit(perPDSLimit).Find(&repos).Error
if err != nil {
return fmt.Errorf("querying DB: %w", err)
}
for _, r := range repos {
if s.queue[r.DID] != nil || s.inProgress[r.DID] != nil {
continue
}
copied := r
s.queue[r.DID] = &copied
reposQueued.Inc()
}
s.updateQueueLenMetrics()
}
return nil
}
func (s *Scheduler) updateQueueLenMetrics() {
queueLength.WithLabelValues("queued").Set(float64(len(s.queue)))
queueLength.WithLabelValues("inProgress").Set(float64(len(s.inProgress)))
}

Worker pool (new file)

@@ -0,0 +1,238 @@
package main
import (
"bytes"
"context"
"fmt"
"net/url"
"strings"
"time"
"github.com/imax9000/errors"
"github.com/rs/zerolog"
"gorm.io/gorm"
"gorm.io/gorm/clause"
comatproto "github.com/bluesky-social/indigo/api/atproto"
"github.com/bluesky-social/indigo/xrpc"
"github.com/uabluerail/bsky-tools/xrpcauth"
"github.com/uabluerail/indexer/models"
"github.com/uabluerail/indexer/repo"
"github.com/uabluerail/indexer/util/resolver"
)
type WorkItem struct {
Repo *repo.Repo
signal chan struct{}
}
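// WorkerPool runs a resizable set of workers that read WorkItems from the
// input channel, fetch each repo from its PDS (rate-limited per host) and
// store the extracted records in the database.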
type WorkerPool struct {
db *gorm.DB
input <-chan WorkItem
limiter *Limiter
workerSignals []chan struct{}
resize chan int
}
func NewWorkerPool(input <-chan WorkItem, db *gorm.DB, size int, limiter *Limiter) *WorkerPool {
r := &WorkerPool{
db: db,
input: input,
limiter: limiter,
resize: make(chan int),
}
r.workerSignals = make([]chan struct{}, size)
for i := range r.workerSignals {
r.workerSignals[i] = make(chan struct{})
}
return r
}
func (p *WorkerPool) Start(ctx context.Context) error {
go p.run(ctx)
return nil
}
func (p *WorkerPool) Resize(ctx context.Context, size int) error {
select {
case <-ctx.Done():
return ctx.Err()
case p.resize <- size:
return nil
}
}
func (p *WorkerPool) run(ctx context.Context) {
for _, ch := range p.workerSignals {
go p.worker(ctx, ch)
}
workerPoolSize.Set(float64(len(p.workerSignals)))
for {
select {
case <-ctx.Done():
for _, ch := range p.workerSignals {
close(ch)
}
// also wait for all workers to stop?
// Return here, otherwise the next loop iteration would close the same channels again and panic.
return
case newSize := <-p.resize:
switch {
case newSize > len(p.workerSignals):
ch := make([]chan struct{}, newSize-len(p.workerSignals))
for i := range ch {
ch[i] = make(chan struct{})
go p.worker(ctx, ch[i])
}
p.workerSignals = append(p.workerSignals, ch...)
workerPoolSize.Set(float64(len(p.workerSignals)))
case newSize < len(p.workerSignals) && newSize > 0:
for _, ch := range p.workerSignals[newSize:] {
close(ch)
}
p.workerSignals = p.workerSignals[:newSize]
workerPoolSize.Set(float64(len(p.workerSignals)))
}
}
}
}
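// worker processes items from the shared input channel until the context is
// cancelled or its signal channel is closed by a pool shutdown/resize.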
func (p *WorkerPool) worker(ctx context.Context, signal chan struct{}) {
log := zerolog.Ctx(ctx)
for {
select {
case <-ctx.Done():
return
case <-signal:
return
case work := <-p.input:
updates := &repo.Repo{}
if err := p.doWork(ctx, work); err != nil {
log.Error().Err(err).Msgf("Work task %q failed: %s", work.Repo.DID, err)
updates.LastError = err.Error()
reposIndexed.WithLabelValues("false").Inc()
} else {
reposIndexed.WithLabelValues("true").Inc()
}
updates.LastIndexAttempt = time.Now()
err := p.db.Model(&repo.Repo{}).
Where(&repo.Repo{Model: gorm.Model{ID: work.Repo.ID}}).
Updates(updates).Error
if err != nil {
log.Error().Err(err).Msgf("Failed to update repo info for %q: %s", work.Repo.DID, err)
}
}
}
}
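// doWork indexes a single repo: resolve its DID to find the PDS endpoint,
// download the full repo (respecting local rate limits and server-side
// throttling), extract the records, upsert them in batches of 500, and finally
// record the new 'rev' for the repo.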
func (p *WorkerPool) doWork(ctx context.Context, work WorkItem) error {
log := zerolog.Ctx(ctx)
defer close(work.signal)
doc, err := resolver.GetDocument(ctx, work.Repo.DID)
if err != nil {
return fmt.Errorf("resolving did %q: %w", work.Repo.DID, err)
}
pdsHost := ""
for _, srv := range doc.Service {
if srv.Type != "AtprotoPersonalDataServer" {
continue
}
pdsHost = srv.ServiceEndpoint
}
if pdsHost == "" {
return fmt.Errorf("did not find any PDS in DID Document")
}
u, err := url.Parse(pdsHost)
if err != nil {
return fmt.Errorf("PDS endpoint (%q) is an invalid URL: %w", pdsHost, err)
}
if u.Host == "" {
return fmt.Errorf("PDS endpoint (%q) doesn't have a host part", pdsHost)
}
client := xrpcauth.NewAnonymousClient(ctx)
client.Host = u.String()
retry:
if p.limiter != nil {
if err := p.limiter.Wait(ctx, u.String()); err != nil {
return fmt.Errorf("failed to wait on rate limiter: %w", err)
}
}
b, err := comatproto.SyncGetRepo(ctx, client, work.Repo.DID, "")
if err != nil {
if err, ok := errors.As[*xrpc.Error](err); ok {
if err.IsThrottled() && err.Ratelimit != nil {
log.Debug().Str("pds", u.String()).Msgf("Hit a rate limit (%s), sleeping until %s", err.Ratelimit.Policy, err.Ratelimit.Reset)
time.Sleep(time.Until(err.Ratelimit.Reset))
goto retry
}
}
reposFetched.WithLabelValues(u.String(), "false").Inc()
return fmt.Errorf("failed to fetch repo: %w", err)
}
reposFetched.WithLabelValues(u.String(), "true").Inc()
newRev, err := repo.GetRev(ctx, bytes.NewReader(b))
if err != nil {
return fmt.Errorf("failed to read 'rev' from the fetched repo: %w", err)
}
newRecs, err := repo.ExtractRecords(ctx, bytes.NewReader(b))
if err != nil {
return fmt.Errorf("failed to extract records: %w", err)
}
recs := []repo.Record{}
for k, v := range newRecs {
parts := strings.SplitN(k, "/", 2)
if len(parts) != 2 {
log.Warn().Msgf("Unexpected key format: %q", k)
continue
}
recs = append(recs, repo.Record{
Repo: models.ID(work.Repo.ID),
Collection: parts[0],
Rkey: parts[1],
Content: v,
})
}
recordsFetched.Add(float64(len(recs)))
if len(recs) > 0 {
for _, batch := range splitInBatches(recs, 500) {
result := p.db.Model(&repo.Record{}).
Clauses(clause.OnConflict{DoUpdates: clause.AssignmentColumns([]string{"content"}),
Columns: []clause.Column{{Name: "repo"}, {Name: "collection"}, {Name: "rkey"}}}).
Create(batch)
if err := result.Error; err != nil {
return fmt.Errorf("inserting records into the database: %w", err)
}
recordsInserted.Add(float64(result.RowsAffected))
}
}
err = p.db.Model(&repo.Repo{}).Where(&repo.Repo{Model: gorm.Model{ID: work.Repo.ID}}).
Updates(&repo.Repo{LastIndexedRev: newRev}).Error
if err != nil {
return fmt.Errorf("updating repo rev: %w", err)
}
return nil
}
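// splitInBatches splits s into consecutive chunks of at most batchSize
// elements, used above to keep bulk inserts to a manageable size.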
func splitInBatches[T any](s []T, batchSize int) [][]T {
var r [][]T
for i := 0; i < len(s); i += batchSize {
if i+batchSize < len(s) {
r = append(r, s[i:i+batchSize])
} else {
r = append(r, s[i:])
}
}
return r
}