Import

parent 2b6abac607
commit 63a767d890

25 changed files with 3027 additions and 0 deletions

cmd/consumer/Dockerfile (Normal file, 14 lines added)
@@ -0,0 +1,14 @@
FROM golang:1.21.1 as builder
WORKDIR /app
COPY go.mod go.sum ./
RUN go mod download
COPY . ./
RUN go build -trimpath ./cmd/consumer

FROM alpine:latest as certs
RUN apk --update add ca-certificates

FROM debian:stable-slim
COPY --from=certs /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/ca-certificates.crt
COPY --from=builder /app/consumer .
ENTRYPOINT ["./consumer"]

cmd/consumer/consumer.go (Normal file, 431 lines added)
@@ -0,0 +1,431 @@
package main

import (
    "bytes"
    "context"
    "errors"
    "fmt"
    "io"
    "math"
    "net/http"
    "net/url"
    "path"
    "strings"
    "time"

    "github.com/gorilla/websocket"
    "github.com/rs/zerolog"
    "gorm.io/gorm"
    "gorm.io/gorm/clause"

    comatproto "github.com/bluesky-social/indigo/api/atproto"
    "github.com/bluesky-social/indigo/xrpc"
    "github.com/ipld/go-ipld-prime/codec/dagcbor"
    "github.com/ipld/go-ipld-prime/datamodel"
    "github.com/ipld/go-ipld-prime/node/basicnode"

    "github.com/uabluerail/indexer/models"
    "github.com/uabluerail/indexer/pds"
    "github.com/uabluerail/indexer/repo"
)

type Consumer struct {
    db     *gorm.DB
    remote pds.PDS

    lastCursorPersist time.Time
}

func NewConsumer(ctx context.Context, remote *pds.PDS, db *gorm.DB) (*Consumer, error) {
    return &Consumer{
        db:     db,
        remote: *remote,
    }, nil
}

func (c *Consumer) Start(ctx context.Context) error {
    go c.run(ctx)
    return nil
}

func (c *Consumer) run(ctx context.Context) {
    log := zerolog.Ctx(ctx).With().Str("pds", c.remote.Host).Logger()
    ctx = log.WithContext(ctx)

    for {
        if err := c.runOnce(ctx); err != nil {
            log.Error().Err(err).Msgf("Consumer of %q failed (will be restarted): %s", c.remote.Host, err)
        }
        time.Sleep(time.Second)
    }
}

func (c *Consumer) runOnce(ctx context.Context) error {
    log := zerolog.Ctx(ctx)

    addr, err := url.Parse(c.remote.Host)
    if err != nil {
        return fmt.Errorf("parsing URL %q: %s", c.remote.Host, err)
    }
    addr.Scheme = "wss"
    addr.Path = path.Join(addr.Path, "xrpc/com.atproto.sync.subscribeRepos")

    if c.remote.Cursor > 0 {
        params := url.Values{"cursor": []string{fmt.Sprint(c.remote.Cursor)}}
        addr.RawQuery = params.Encode()
    }

    conn, _, err := websocket.DefaultDialer.DialContext(ctx, addr.String(), http.Header{})
    if err != nil {
        return fmt.Errorf("establishing websocket connection: %w", err)
    }
    defer conn.Close()

    ch := make(chan bool)
    defer close(ch)
    go func() {
        t := time.NewTicker(time.Minute)
        for {
            select {
            case <-ch:
                return
            case <-t.C:
                if err := conn.WriteControl(websocket.PingMessage, []byte("ping"), time.Now().Add(time.Minute)); err != nil {
                    log.Error().Err(err).Msgf("Failed to send ping: %s", err)
                }
            }
        }
    }()

    first := true
    for {
        _, b, err := conn.ReadMessage()
        if err != nil {
            return fmt.Errorf("websocket.ReadMessage: %w", err)
        }

        r := bytes.NewReader(b)
        proto := basicnode.Prototype.Any
        headerNode := proto.NewBuilder()
        if err := (&dagcbor.DecodeOptions{DontParseBeyondEnd: true}).Decode(headerNode, r); err != nil {
            return fmt.Errorf("unmarshaling message header: %w", err)
        }
        header, err := parseHeader(headerNode.Build())
        if err != nil {
            return fmt.Errorf("parsing message header: %w", err)
        }
        switch header.Op {
        case 1:
            if err := c.processMessage(ctx, header.Type, r, first); err != nil {
                return err
            }
        case -1:
            bodyNode := proto.NewBuilder()
            if err := (&dagcbor.DecodeOptions{DontParseBeyondEnd: true, AllowLinks: true}).Decode(bodyNode, r); err != nil {
                return fmt.Errorf("unmarshaling message body: %w", err)
            }
            body, err := parseError(bodyNode.Build())
            if err != nil {
                return fmt.Errorf("parsing error payload: %w", err)
            }
            return &body
        default:
            log.Warn().Msgf("Unknown 'op' value received: %d", header.Op)
        }
        first = false
    }
}

func (c *Consumer) checkForCursorReset(ctx context.Context, seq int64) error {
    // hack to detect cursor resets upon connection for implementations
    // that don't emit an explicit #info when connecting with an outdated cursor.

    if seq == c.remote.Cursor+1 {
        // No reset.
        return nil
    }

    return c.resetCursor(ctx, seq)
}

func (c *Consumer) resetCursor(ctx context.Context, seq int64) error {
    zerolog.Ctx(ctx).Warn().Str("pds", c.remote.Host).Msgf("Cursor reset: %d -> %d", c.remote.Cursor, seq)
    err := c.db.Model(&c.remote).
        Where(&pds.PDS{Model: gorm.Model{ID: c.remote.ID}}).
        Updates(&pds.PDS{FirstCursorSinceReset: seq}).Error
    if err != nil {
        return fmt.Errorf("updating FirstCursorSinceReset: %w", err)
    }
    c.remote.FirstCursorSinceReset = seq
    return nil
}

func (c *Consumer) updateCursor(ctx context.Context, seq int64) error {
    if math.Abs(float64(seq-c.remote.Cursor)) < 100 && time.Since(c.lastCursorPersist) < 5*time.Second {
        c.remote.Cursor = seq
        return nil
    }

    err := c.db.Model(&c.remote).
        Where(&pds.PDS{Model: gorm.Model{ID: c.remote.ID}}).
        Updates(&pds.PDS{Cursor: seq}).Error
    if err != nil {
        return fmt.Errorf("updating Cursor: %w", err)
    }
    c.remote.Cursor = seq
    // Remember when we last persisted the cursor, so the throttling check above works.
    c.lastCursorPersist = time.Now()
    return nil
}

func (c *Consumer) processMessage(ctx context.Context, typ string, r io.Reader, first bool) error {
    log := zerolog.Ctx(ctx)

    switch typ {
    case "#commit":
        payload := &comatproto.SyncSubscribeRepos_Commit{}
        if err := payload.UnmarshalCBOR(r); err != nil {
            return fmt.Errorf("failed to unmarshal commit: %w", err)
        }

        if c.remote.FirstCursorSinceReset == 0 {
            if err := c.resetCursor(ctx, payload.Seq); err != nil {
                return fmt.Errorf("handling cursor reset: %w", err)
            }
        }
        if first {
            if err := c.checkForCursorReset(ctx, payload.Seq); err != nil {
                return err
            }
        }

        repoInfo, err := repo.EnsureExists(ctx, c.db, payload.Repo)
        if err != nil {
            return fmt.Errorf("repo.EnsureExists(%q): %w", payload.Repo, err)
        }
        if repoInfo.PDS != models.ID(c.remote.ID) {
            log.Error().Str("did", payload.Repo).Str("rev", payload.Rev).
                Msgf("Commit from an incorrect PDS, skipping")
            return nil
        }

        // TODO: verify signature

        expectRecords := false
        deletions := []string{}
        for _, op := range payload.Ops {
            switch op.Action {
            case "create":
                expectRecords = true
            case "update":
                expectRecords = true
            case "delete":
                deletions = append(deletions, op.Path)
            }
        }
        for _, d := range deletions {
            parts := strings.SplitN(d, "/", 2)
            if len(parts) != 2 {
                continue
            }
            err := c.db.Model(&repo.Record{}).
                Where(&repo.Record{
                    Repo:       models.ID(repoInfo.ID),
                    Collection: parts[0],
                    Rkey:       parts[1]}).
                Updates(&repo.Record{Deleted: true}).Error
            if err != nil {
                return fmt.Errorf("failed to mark %s/%s as deleted: %w", payload.Repo, d, err)
            }
        }

        newRecs, err := repo.ExtractRecords(ctx, bytes.NewReader(payload.Blocks))
        if err != nil {
            return fmt.Errorf("failed to extract records: %w", err)
        }
        recs := []repo.Record{}
        for k, v := range newRecs {
            parts := strings.SplitN(k, "/", 2)
            if len(parts) != 2 {
                log.Warn().Msgf("Unexpected key format: %q", k)
                continue
            }
            recs = append(recs, repo.Record{
                Repo:       models.ID(repoInfo.ID),
                Collection: parts[0],
                Rkey:       parts[1],
                Content:    v,
            })
        }
        if len(recs) == 0 && expectRecords {
            log.Debug().Int64("seq", payload.Seq).Str("pds", c.remote.Host).Msgf("len(recs) == 0")
        }
        if len(recs) > 0 || expectRecords {
            err = c.db.Model(&repo.Record{}).
                Clauses(clause.OnConflict{DoUpdates: clause.AssignmentColumns([]string{"content"}),
                    Columns: []clause.Column{{Name: "repo"}, {Name: "collection"}, {Name: "rkey"}}}).
                Create(recs).Error
            if err != nil {
                return fmt.Errorf("inserting records into the database: %w", err)
            }
        }

        if payload.TooBig {
            // Just trigger a re-index by resetting rev.
            err := c.db.Model(&repo.Repo{}).Where(&repo.Repo{Model: gorm.Model{ID: repoInfo.ID}}).
                Updates(&repo.Repo{
                    FirstCursorSinceReset: c.remote.FirstCursorSinceReset,
                    FirstRevSinceReset:    payload.Rev,
                }).Error
            if err != nil {
                return fmt.Errorf("failed to update repo info after cursor reset: %w", err)
            }
        }

        if repoInfo.FirstCursorSinceReset != c.remote.FirstCursorSinceReset {
            err := c.db.Model(&repo.Repo{}).Where(&repo.Repo{Model: gorm.Model{ID: repoInfo.ID}}).
                Updates(&repo.Repo{
                    FirstCursorSinceReset: c.remote.FirstCursorSinceReset,
                    FirstRevSinceReset:    payload.Rev,
                }).Error
            if err != nil {
                return fmt.Errorf("failed to update repo info after cursor reset: %w", err)
            }
        }

        if err := c.updateCursor(ctx, payload.Seq); err != nil {
            return err
        }
    case "#handle":
        payload := &comatproto.SyncSubscribeRepos_Handle{}
        if err := payload.UnmarshalCBOR(r); err != nil {
            return fmt.Errorf("failed to unmarshal #handle message: %w", err)
        }

        if c.remote.FirstCursorSinceReset == 0 {
            if err := c.resetCursor(ctx, payload.Seq); err != nil {
                return fmt.Errorf("handling cursor reset: %w", err)
            }
        }
        if first {
            if err := c.checkForCursorReset(ctx, payload.Seq); err != nil {
                return err
            }
        }
        // No-op, we don't store handles.
        if err := c.updateCursor(ctx, payload.Seq); err != nil {
            return err
        }
    case "#migrate":
        payload := &comatproto.SyncSubscribeRepos_Migrate{}
        if err := payload.UnmarshalCBOR(r); err != nil {
            return fmt.Errorf("failed to unmarshal #migrate message: %w", err)
        }

        if c.remote.FirstCursorSinceReset == 0 {
            if err := c.resetCursor(ctx, payload.Seq); err != nil {
                return fmt.Errorf("handling cursor reset: %w", err)
            }
        }
        if first {
            if err := c.checkForCursorReset(ctx, payload.Seq); err != nil {
                return err
            }
        }

        log.Debug().Interface("payload", payload).Str("did", payload.Did).Msgf("MIGRATION")
        // TODO
        if err := c.updateCursor(ctx, payload.Seq); err != nil {
            return err
        }
    case "#tombstone":
        payload := &comatproto.SyncSubscribeRepos_Tombstone{}
        if err := payload.UnmarshalCBOR(r); err != nil {
            return fmt.Errorf("failed to unmarshal #tombstone message: %w", err)
        }

        if c.remote.FirstCursorSinceReset == 0 {
            if err := c.resetCursor(ctx, payload.Seq); err != nil {
                return fmt.Errorf("handling cursor reset: %w", err)
            }
        }
        if first {
            if err := c.checkForCursorReset(ctx, payload.Seq); err != nil {
                return err
            }
        }
        // TODO
        if err := c.updateCursor(ctx, payload.Seq); err != nil {
            return err
        }
    case "#info":
        payload := &comatproto.SyncSubscribeRepos_Info{}
        if err := payload.UnmarshalCBOR(r); err != nil {
            return fmt.Errorf("failed to unmarshal #info message: %w", err)
        }
        switch payload.Name {
        case "OutdatedCursor":
            if !first {
                log.Warn().Msgf("Received cursor reset notification in the middle of a stream: %+v", payload)
            }
            c.remote.FirstCursorSinceReset = 0
        default:
            log.Error().Msgf("Unknown #info message %q: %+v", payload.Name, payload)
        }
    default:
        log.Warn().Msgf("Unknown message type received: %s", typ)
    }
    return nil
}

type Header struct {
    Op   int64
    Type string
}

func parseHeader(node datamodel.Node) (Header, error) {
    r := Header{}
    op, err := node.LookupByString("op")
    if err != nil {
        return r, fmt.Errorf("missing 'op': %w", err)
    }
    r.Op, err = op.AsInt()
    if err != nil {
        return r, fmt.Errorf("op.AsInt(): %w", err)
    }
    if r.Op == -1 {
        // Error frame, type should not be present
        return r, nil
    }
    t, err := node.LookupByString("t")
    if err != nil {
        return r, fmt.Errorf("missing 't': %w", err)
    }
    r.Type, err = t.AsString()
    if err != nil {
        return r, fmt.Errorf("t.AsString(): %w", err)
    }
    return r, nil
}

func parseError(node datamodel.Node) (xrpc.XRPCError, error) {
    r := xrpc.XRPCError{}
    e, err := node.LookupByString("error")
    if err != nil {
        return r, fmt.Errorf("missing 'error': %w", err)
    }
    r.ErrStr, err = e.AsString()
    if err != nil {
        return r, fmt.Errorf("error.AsString(): %w", err)
    }
    m, err := node.LookupByString("message")
    if err == nil {
        r.Message, err = m.AsString()
        if err != nil {
            return r, fmt.Errorf("message.AsString(): %w", err)
        }
    } else if !errors.Is(err, datamodel.ErrNotExists{}) {
        return r, fmt.Errorf("looking up 'message': %w", err)
    }

    return r, nil
}

cmd/consumer/main.go (Normal file, 176 lines added)
@@ -0,0 +1,176 @@
package main

import (
    "context"
    "flag"
    "fmt"
    "io"
    "log"
    "net/http"
    _ "net/http/pprof"
    "os"
    "os/signal"
    "path/filepath"
    "runtime"
    "runtime/debug"
    "strings"
    "syscall"
    "time"

    _ "github.com/joho/godotenv/autoload"
    "github.com/kelseyhightower/envconfig"
    "github.com/prometheus/client_golang/prometheus/promhttp"
    "github.com/rs/zerolog"
    "gorm.io/driver/postgres"
    "gorm.io/gorm"
    "gorm.io/gorm/logger"

    "github.com/uabluerail/indexer/pds"
    "github.com/uabluerail/indexer/util/gormzerolog"
)

type Config struct {
    LogFile     string
    LogFormat   string `default:"text"`
    LogLevel    int64  `default:"1"`
    MetricsPort string `split_words:"true"`
    DBUrl       string `envconfig:"POSTGRES_URL"`
}

var config Config

func runMain(ctx context.Context) error {
    ctx = setupLogging(ctx)
    log := zerolog.Ctx(ctx)
    log.Debug().Msgf("Starting up...")
    db, err := gorm.Open(postgres.Open(config.DBUrl), &gorm.Config{
        Logger: gormzerolog.New(&logger.Config{
            SlowThreshold:             3 * time.Second,
            IgnoreRecordNotFoundError: true,
        }, nil),
    })
    if err != nil {
        return fmt.Errorf("connecting to the database: %w", err)
    }
    log.Debug().Msgf("DB connection established")

    remotes := []pds.PDS{}
    if err := db.Find(&remotes).Error; err != nil {
        return fmt.Errorf("listing known PDSs: %w", err)
    }
    // TODO: check for changes and start/stop consumers as needed
    for _, remote := range remotes {
        c, err := NewConsumer(ctx, &remote, db)
        if err != nil {
            return fmt.Errorf("failed to create a consumer for %q: %w", remote.Host, err)
        }
        if err := c.Start(ctx); err != nil {
            return fmt.Errorf("failed to start a consumer for %q: %w", remote.Host, err)
        }
    }

    log.Info().Msgf("Starting HTTP listener on %q...", config.MetricsPort)
    http.Handle("/metrics", promhttp.Handler())
    srv := &http.Server{Addr: fmt.Sprintf(":%s", config.MetricsPort)}
    errCh := make(chan error)
    go func() {
        errCh <- srv.ListenAndServe()
    }()
    select {
    case <-ctx.Done():
        if err := srv.Shutdown(context.Background()); err != nil {
            return fmt.Errorf("HTTP server shutdown failed: %w", err)
        }
    }
    return <-errCh
}

func main() {
    flag.StringVar(&config.LogFile, "log", "", "Path to the log file. If empty, will log to stderr")
    flag.StringVar(&config.LogFormat, "log-format", "text", "Logging format. 'text' or 'json'")
    flag.Int64Var(&config.LogLevel, "log-level", 1, "Log level. -1 - trace, 0 - debug, 1 - info, 5 - panic")

    if err := envconfig.Process("consumer", &config); err != nil {
        log.Fatalf("envconfig.Process: %s", err)
    }

    flag.Parse()

    ctx, _ := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
    if err := runMain(ctx); err != nil {
        fmt.Fprintln(os.Stderr, err)
        os.Exit(1)
    }
}

func setupLogging(ctx context.Context) context.Context {
    logFile := os.Stderr

    if config.LogFile != "" {
        f, err := os.OpenFile(config.LogFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
        if err != nil {
            log.Fatalf("Failed to open the specified log file %q: %s", config.LogFile, err)
        }
        logFile = f
    }

    var output io.Writer

    switch config.LogFormat {
    case "json":
        output = logFile
    case "text":
        prefixList := []string{}
        info, ok := debug.ReadBuildInfo()
        if ok {
            prefixList = append(prefixList, info.Path+"/")
        }

        basedir := ""
        _, sourceFile, _, ok := runtime.Caller(0)
        if ok {
            basedir = filepath.Dir(sourceFile)
        }

        if basedir != "" && strings.HasPrefix(basedir, "/") {
            prefixList = append(prefixList, basedir+"/")
            head, _ := filepath.Split(basedir)
            for head != "/" {
                prefixList = append(prefixList, head)
                head, _ = filepath.Split(strings.TrimSuffix(head, "/"))
            }
        }

        output = zerolog.ConsoleWriter{
            Out:        logFile,
            NoColor:    true,
            TimeFormat: time.RFC3339,
            PartsOrder: []string{
                zerolog.LevelFieldName,
                zerolog.TimestampFieldName,
                zerolog.CallerFieldName,
                zerolog.MessageFieldName,
            },
            FormatFieldName:  func(i interface{}) string { return fmt.Sprintf("%s:", i) },
            FormatFieldValue: func(i interface{}) string { return fmt.Sprintf("%s", i) },
            FormatCaller: func(i interface{}) string {
                s := i.(string)
                for _, p := range prefixList {
                    s = strings.TrimPrefix(s, p)
                }
                return s
            },
        }
    default:
        log.Fatalf("Invalid log format specified: %q", config.LogFormat)
    }

    logger := zerolog.New(output).Level(zerolog.Level(config.LogLevel)).With().Caller().Timestamp().Logger()

    ctx = logger.WithContext(ctx)

    zerolog.DefaultContextLogger = &logger
    log.SetOutput(logger)

    return ctx
}

cmd/lister/Dockerfile (Normal file, 14 lines added)
@@ -0,0 +1,14 @@
FROM golang:1.21.1 as builder
WORKDIR /app
COPY go.mod go.sum ./
RUN go mod download
COPY . ./
RUN go build -trimpath ./cmd/lister

FROM alpine:latest as certs
RUN apk --update add ca-certificates

FROM debian:stable-slim
COPY --from=certs /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/ca-certificates.crt
COPY --from=builder /app/lister .
ENTRYPOINT ["./lister"]

cmd/lister/lister.go (Normal file, 108 lines added)
@@ -0,0 +1,108 @@
package main

import (
    "context"
    "errors"
    "time"

    "github.com/rs/zerolog"
    "gorm.io/gorm"

    comatproto "github.com/bluesky-social/indigo/api/atproto"
    "github.com/bluesky-social/indigo/did"

    "github.com/uabluerail/bsky-tools/pagination"
    "github.com/uabluerail/bsky-tools/xrpcauth"
    "github.com/uabluerail/indexer/pds"
    "github.com/uabluerail/indexer/repo"
    "github.com/uabluerail/indexer/util/resolver"
)

type Lister struct {
    db       *gorm.DB
    resolver did.Resolver

    pollInterval        time.Duration
    listRefreshInterval time.Duration
}

func NewLister(ctx context.Context, db *gorm.DB) (*Lister, error) {
    return &Lister{
        db:                  db,
        resolver:            resolver.Resolver,
        pollInterval:        5 * time.Minute,
        listRefreshInterval: 24 * time.Hour,
    }, nil
}

func (l *Lister) Start(ctx context.Context) error {
    go l.run(ctx)
    return nil
}

func (l *Lister) run(ctx context.Context) {
    log := zerolog.Ctx(ctx)
    ticker := time.NewTicker(l.pollInterval)

    log.Info().Msgf("Lister starting...")
    t := make(chan time.Time, 1)
    t <- time.Now()
    for {
        select {
        case <-ctx.Done():
            log.Info().Msgf("Lister stopped (context expired)")
            return
        case <-t:
            db := l.db.WithContext(ctx)

            remote := pds.PDS{}
            if err := db.Model(&remote).
                Where("last_list is null or last_list < ?", time.Now().Add(-l.listRefreshInterval)).
                Take(&remote).Error; err != nil {
                if !errors.Is(err, gorm.ErrRecordNotFound) {
                    log.Error().Err(err).Msgf("Failed to query DB for a PDS to list repos from: %s", err)
                }
                break
            }
            client := xrpcauth.NewAnonymousClient(ctx)
            client.Host = remote.Host

            log.Info().Msgf("Listing repos from %q...", remote.Host)
            dids, err := pagination.Reduce(
                func(cursor string) (resp *comatproto.SyncListRepos_Output, nextCursor string, err error) {
                    resp, err = comatproto.SyncListRepos(ctx, client, cursor, 200)
                    if err == nil && resp.Cursor != nil {
                        nextCursor = *resp.Cursor
                    }
                    return
                },
                func(resp *comatproto.SyncListRepos_Output, acc []string) ([]string, error) {
                    for _, repo := range resp.Repos {
                        if repo == nil {
                            continue
                        }
                        acc = append(acc, repo.Did)
                    }
                    return acc, nil
                })

            if err != nil {
                log.Error().Err(err).Msgf("Failed to list repos from %q: %s", remote.Host, err)
                break
            }
            log.Info().Msgf("Received %d DIDs from %q", len(dids), remote.Host)

            for _, did := range dids {
                if _, err := repo.EnsureExists(ctx, l.db, did); err != nil {
                    log.Error().Err(err).Msgf("Failed to ensure that we have a record for the repo %q: %s", did, err)
                }
            }

            if err := db.Model(&remote).Updates(&pds.PDS{LastList: time.Now()}).Error; err != nil {
                log.Error().Err(err).Msgf("Failed to update the timestamp of last list for %q: %s", remote.Host, err)
            }
        case v := <-ticker.C:
            t <- v
        }
    }
}

cmd/lister/main.go (Normal file, 180 lines added)
@@ -0,0 +1,180 @@
package main

import (
    "context"
    "flag"
    "fmt"
    "io"
    "log"
    "net/http"
    _ "net/http/pprof"
    "os"
    "os/signal"
    "path/filepath"
    "runtime"
    "runtime/debug"
    "strings"
    "syscall"
    "time"

    _ "github.com/joho/godotenv/autoload"
    "github.com/kelseyhightower/envconfig"
    "github.com/prometheus/client_golang/prometheus/promhttp"
    "github.com/rs/zerolog"
    "gorm.io/driver/postgres"
    "gorm.io/gorm"
    "gorm.io/gorm/logger"

    "github.com/uabluerail/indexer/pds"
    "github.com/uabluerail/indexer/repo"
    "github.com/uabluerail/indexer/util/gormzerolog"
)

type Config struct {
    LogFile     string
    LogFormat   string `default:"text"`
    LogLevel    int64  `default:"1"`
    MetricsPort string `split_words:"true"`
    DBUrl       string `envconfig:"POSTGRES_URL"`
}

var config Config

func runMain(ctx context.Context) error {
    ctx = setupLogging(ctx)
    log := zerolog.Ctx(ctx)
    log.Debug().Msgf("Starting up...")
    db, err := gorm.Open(postgres.Open(config.DBUrl), &gorm.Config{
        Logger: gormzerolog.New(&logger.Config{
            SlowThreshold:             1 * time.Second,
            IgnoreRecordNotFoundError: true,
        }, nil),
    })
    if err != nil {
        return fmt.Errorf("connecting to the database: %w", err)
    }
    log.Debug().Msgf("DB connection established")

    for _, f := range []func(*gorm.DB) error{
        pds.AutoMigrate,
        repo.AutoMigrate,
    } {
        if err := f(db); err != nil {
            return fmt.Errorf("auto-migrating DB schema: %w", err)
        }
    }
    log.Debug().Msgf("DB schema updated")

    lister, err := NewLister(ctx, db)
    if err != nil {
        return fmt.Errorf("failed to create lister: %w", err)
    }
    if err := lister.Start(ctx); err != nil {
        return fmt.Errorf("failed to start lister: %w", err)
    }

    log.Info().Msgf("Starting HTTP listener on %q...", config.MetricsPort)
    http.Handle("/metrics", promhttp.Handler())
    srv := &http.Server{Addr: fmt.Sprintf(":%s", config.MetricsPort)}
    errCh := make(chan error)
    go func() {
        errCh <- srv.ListenAndServe()
    }()
    select {
    case <-ctx.Done():
        if err := srv.Shutdown(context.Background()); err != nil {
            return fmt.Errorf("HTTP server shutdown failed: %w", err)
        }
    }
    return <-errCh
}

func main() {
    flag.StringVar(&config.LogFile, "log", "", "Path to the log file. If empty, will log to stderr")
    flag.StringVar(&config.LogFormat, "log-format", "text", "Logging format. 'text' or 'json'")
    flag.Int64Var(&config.LogLevel, "log-level", 1, "Log level. -1 - trace, 0 - debug, 1 - info, 5 - panic")

    if err := envconfig.Process("lister", &config); err != nil {
        log.Fatalf("envconfig.Process: %s", err)
    }

    flag.Parse()

    ctx, _ := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
    if err := runMain(ctx); err != nil {
        fmt.Fprintln(os.Stderr, err)
        os.Exit(1)
    }
}

func setupLogging(ctx context.Context) context.Context {
    logFile := os.Stderr

    if config.LogFile != "" {
        f, err := os.OpenFile(config.LogFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
        if err != nil {
            log.Fatalf("Failed to open the specified log file %q: %s", config.LogFile, err)
        }
        logFile = f
    }

    var output io.Writer

    switch config.LogFormat {
    case "json":
        output = logFile
    case "text":
        prefixList := []string{}
        info, ok := debug.ReadBuildInfo()
        if ok {
            prefixList = append(prefixList, info.Path+"/")
        }

        basedir := ""
        _, sourceFile, _, ok := runtime.Caller(0)
        if ok {
            basedir = filepath.Dir(sourceFile)
        }

        if basedir != "" && strings.HasPrefix(basedir, "/") {
            prefixList = append(prefixList, basedir+"/")
            head, _ := filepath.Split(basedir)
            for head != "/" {
                prefixList = append(prefixList, head)
                head, _ = filepath.Split(strings.TrimSuffix(head, "/"))
            }
        }

        output = zerolog.ConsoleWriter{
            Out:        logFile,
            NoColor:    true,
            TimeFormat: time.RFC3339,
            PartsOrder: []string{
                zerolog.LevelFieldName,
                zerolog.TimestampFieldName,
                zerolog.CallerFieldName,
                zerolog.MessageFieldName,
            },
            FormatFieldName:  func(i interface{}) string { return fmt.Sprintf("%s:", i) },
            FormatFieldValue: func(i interface{}) string { return fmt.Sprintf("%s", i) },
            FormatCaller: func(i interface{}) string {
                s := i.(string)
                for _, p := range prefixList {
                    s = strings.TrimPrefix(s, p)
                }
                return s
            },
        }
    default:
        log.Fatalf("Invalid log format specified: %q", config.LogFormat)
    }

    logger := zerolog.New(output).Level(zerolog.Level(config.LogLevel)).With().Caller().Timestamp().Logger()

    ctx = logger.WithContext(ctx)

    zerolog.DefaultContextLogger = &logger
    log.SetOutput(logger)

    return ctx
}

cmd/record-indexer/Dockerfile (Normal file, 14 lines added)
@@ -0,0 +1,14 @@
FROM golang:1.21.1 as builder
WORKDIR /app
COPY go.mod go.sum ./
RUN go mod download
COPY . ./
RUN go build -trimpath ./cmd/record-indexer

FROM alpine:latest as certs
RUN apk --update add ca-certificates

FROM debian:stable-slim
COPY --from=certs /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/ca-certificates.crt
COPY --from=builder /app/record-indexer .
ENTRYPOINT ["./record-indexer"]

cmd/record-indexer/admin.go (Normal file, 77 lines added)
@@ -0,0 +1,77 @@
package main

import (
    "context"
    "fmt"
    "net/http"
    "strconv"

    "golang.org/x/time/rate"
)

func AddAdminHandlers(limiter *Limiter, pool *WorkerPool) {
    http.HandleFunc("/rate/set", handleRateSet(limiter))
    http.HandleFunc("/rate/setAll", handleRateSetAll(limiter))
    http.HandleFunc("/pool/resize", handlePoolResize(pool))
}

func handlePoolResize(pool *WorkerPool) http.HandlerFunc {
    return func(w http.ResponseWriter, r *http.Request) {
        s := r.FormValue("size")
        if s == "" {
            http.Error(w, "need size", http.StatusBadRequest)
            return
        }

        size, err := strconv.Atoi(s)
        if err != nil {
            http.Error(w, err.Error(), http.StatusBadRequest)
            return
        }

        pool.Resize(context.Background(), size)
        fmt.Fprintln(w, "OK")
    }
}

func handleRateSet(limiter *Limiter) http.HandlerFunc {
    return func(w http.ResponseWriter, r *http.Request) {
        s := r.FormValue("limit")
        if s == "" {
            http.Error(w, "need limit", http.StatusBadRequest)
            return
        }
        name := r.FormValue("name")
        if name == "" {
            http.Error(w, "need name", http.StatusBadRequest)
            return
        }

        limit, err := strconv.Atoi(s)
        if err != nil {
            http.Error(w, err.Error(), http.StatusBadRequest)
            return
        }

        limiter.SetLimit(context.Background(), name, rate.Limit(limit))
        fmt.Fprintln(w, "OK")
    }
}

func handleRateSetAll(limiter *Limiter) http.HandlerFunc {
    return func(w http.ResponseWriter, r *http.Request) {
        s := r.FormValue("limit")
        if s == "" {
            http.Error(w, "need limit", http.StatusBadRequest)
            return
        }
        limit, err := strconv.Atoi(s)
        if err != nil {
            http.Error(w, err.Error(), http.StatusBadRequest)
            return
        }

        limiter.SetAllLimits(context.Background(), rate.Limit(limit))
        fmt.Fprintln(w, "OK")
    }
}

cmd/record-indexer/main.go (Normal file, 179 lines added)
@@ -0,0 +1,179 @@
package main

import (
    "context"
    "flag"
    "fmt"
    "io"
    "log"
    "net/http"
    _ "net/http/pprof"
    "os"
    "os/signal"
    "path/filepath"
    "runtime"
    "runtime/debug"
    "strings"
    "syscall"
    "time"

    _ "github.com/joho/godotenv/autoload"
    "github.com/kelseyhightower/envconfig"
    "github.com/prometheus/client_golang/prometheus/promhttp"
    "github.com/rs/zerolog"
    "gorm.io/driver/postgres"
    "gorm.io/gorm"
    "gorm.io/gorm/logger"

    "github.com/uabluerail/indexer/util/gormzerolog"
)

type Config struct {
    LogFile     string
    LogFormat   string `default:"text"`
    LogLevel    int64  `default:"1"`
    MetricsPort string `split_words:"true"`
    DBUrl       string `envconfig:"POSTGRES_URL"`
    Workers     int    `default:"2"`
}

var config Config

func runMain(ctx context.Context) error {
    ctx = setupLogging(ctx)
    log := zerolog.Ctx(ctx)
    log.Debug().Msgf("Starting up...")
    db, err := gorm.Open(postgres.Open(config.DBUrl), &gorm.Config{
        Logger: gormzerolog.New(&logger.Config{
            SlowThreshold:             3 * time.Second,
            IgnoreRecordNotFoundError: true,
        }, nil),
    })
    if err != nil {
        return fmt.Errorf("connecting to the database: %w", err)
    }
    log.Debug().Msgf("DB connection established")

    limiter, err := NewLimiter(db)
    if err != nil {
        return fmt.Errorf("failed to create limiter: %w", err)
    }

    ch := make(chan WorkItem)
    pool := NewWorkerPool(ch, db, config.Workers, limiter)
    if err := pool.Start(ctx); err != nil {
        return fmt.Errorf("failed to start worker pool: %w", err)
    }

    scheduler := NewScheduler(ch, db)
    if err := scheduler.Start(ctx); err != nil {
        return fmt.Errorf("failed to start scheduler: %w", err)
    }

    log.Info().Msgf("Starting HTTP listener on %q...", config.MetricsPort)
    AddAdminHandlers(limiter, pool)
    http.Handle("/metrics", promhttp.Handler())
    srv := &http.Server{Addr: fmt.Sprintf(":%s", config.MetricsPort)}
    errCh := make(chan error)
    go func() {
        errCh <- srv.ListenAndServe()
    }()
    select {
    case <-ctx.Done():
        if err := srv.Shutdown(context.Background()); err != nil {
            return fmt.Errorf("HTTP server shutdown failed: %w", err)
        }
    }
    return <-errCh
}

func main() {
    flag.StringVar(&config.LogFile, "log", "", "Path to the log file. If empty, will log to stderr")
    flag.StringVar(&config.LogFormat, "log-format", "text", "Logging format. 'text' or 'json'")
    flag.Int64Var(&config.LogLevel, "log-level", 1, "Log level. -1 - trace, 0 - debug, 1 - info, 5 - panic")
    flag.IntVar(&config.Workers, "workers", 2, "Number of workers to start with")

    if err := envconfig.Process("indexer", &config); err != nil {
        log.Fatalf("envconfig.Process: %s", err)
    }

    flag.Parse()

    ctx, _ := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
    if err := runMain(ctx); err != nil {
        fmt.Fprintln(os.Stderr, err)
        os.Exit(1)
    }
}

func setupLogging(ctx context.Context) context.Context {
    logFile := os.Stderr

    if config.LogFile != "" {
        f, err := os.OpenFile(config.LogFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
        if err != nil {
            log.Fatalf("Failed to open the specified log file %q: %s", config.LogFile, err)
        }
        logFile = f
    }

    var output io.Writer

    switch config.LogFormat {
    case "json":
        output = logFile
    case "text":
        prefixList := []string{}
        info, ok := debug.ReadBuildInfo()
        if ok {
            prefixList = append(prefixList, info.Path+"/")
        }

        basedir := ""
        _, sourceFile, _, ok := runtime.Caller(0)
        if ok {
            basedir = filepath.Dir(sourceFile)
        }

        if basedir != "" && strings.HasPrefix(basedir, "/") {
            prefixList = append(prefixList, basedir+"/")
            head, _ := filepath.Split(basedir)
            for head != "/" {
                prefixList = append(prefixList, head)
                head, _ = filepath.Split(strings.TrimSuffix(head, "/"))
            }
        }

        output = zerolog.ConsoleWriter{
            Out:        logFile,
            NoColor:    true,
            TimeFormat: time.RFC3339,
            PartsOrder: []string{
                zerolog.LevelFieldName,
                zerolog.TimestampFieldName,
                zerolog.CallerFieldName,
                zerolog.MessageFieldName,
            },
            FormatFieldName:  func(i interface{}) string { return fmt.Sprintf("%s:", i) },
            FormatFieldValue: func(i interface{}) string { return fmt.Sprintf("%s", i) },
            FormatCaller: func(i interface{}) string {
                s := i.(string)
                for _, p := range prefixList {
                    s = strings.TrimPrefix(s, p)
                }
                return s
            },
        }
    default:
        log.Fatalf("Invalid log format specified: %q", config.LogFormat)
    }

    logger := zerolog.New(output).Level(zerolog.Level(config.LogLevel)).With().Caller().Timestamp().Logger()

    ctx = logger.WithContext(ctx)

    zerolog.DefaultContextLogger = &logger
    log.SetOutput(logger)

    return ctx
}

cmd/record-indexer/metrics.go (Normal file, 41 lines added)
@@ -0,0 +1,41 @@
package main

import (
    "github.com/prometheus/client_golang/prometheus"
    "github.com/prometheus/client_golang/prometheus/promauto"
)

var reposQueued = promauto.NewCounter(prometheus.CounterOpts{
    Name: "indexer_repos_queued_count",
    Help: "Number of repos added to the queue",
})

var queueLength = promauto.NewGaugeVec(prometheus.GaugeOpts{
    Name: "indexer_queue_length",
    Help: "Current length of indexing queue",
}, []string{"state"})

var reposFetched = promauto.NewCounterVec(prometheus.CounterOpts{
    Name: "indexer_repos_fetched_count",
    Help: "Number of repos fetched",
}, []string{"remote", "success"})

var reposIndexed = promauto.NewCounterVec(prometheus.CounterOpts{
    Name: "indexer_repos_indexed_count",
    Help: "Number of repos indexed",
}, []string{"success"})

var recordsFetched = promauto.NewCounter(prometheus.CounterOpts{
    Name: "indexer_records_fetched_count",
    Help: "Number of records fetched",
})

var recordsInserted = promauto.NewCounter(prometheus.CounterOpts{
    Name: "indexer_records_inserted_count",
    Help: "Number of records inserted into DB",
})

var workerPoolSize = promauto.NewGauge(prometheus.GaugeOpts{
    Name: "indexer_workers_count",
    Help: "Current number of workers running",
})

cmd/record-indexer/ratelimits.go (Normal file, 82 lines added)
@@ -0,0 +1,82 @@
package main

import (
    "context"
    "fmt"
    "sync"

    "github.com/rs/zerolog"
    "github.com/uabluerail/indexer/pds"
    "golang.org/x/time/rate"
    "gorm.io/gorm"
)

const defaultRateLimit = 10

type Limiter struct {
    mu      sync.RWMutex
    db      *gorm.DB
    limiter map[string]*rate.Limiter
}

func NewLimiter(db *gorm.DB) (*Limiter, error) {
    remotes := []pds.PDS{}

    if err := db.Find(&remotes).Error; err != nil {
        return nil, fmt.Errorf("failed to get the list of known PDSs: %w", err)
    }

    l := &Limiter{
        db:      db,
        limiter: map[string]*rate.Limiter{},
    }

    for _, remote := range remotes {
        limit := remote.CrawlLimit
        if limit == 0 {
            limit = defaultRateLimit
        }
        l.limiter[remote.Host] = rate.NewLimiter(rate.Limit(limit), limit*2)
    }
    return l, nil
}

func (l *Limiter) getLimiter(name string) *rate.Limiter {
    l.mu.RLock()
    limiter := l.limiter[name]
    l.mu.RUnlock()

    if limiter != nil {
        return limiter
    }

    limiter = rate.NewLimiter(defaultRateLimit, defaultRateLimit*2)
    l.mu.Lock()
    l.limiter[name] = limiter
    l.mu.Unlock()
    return limiter
}

func (l *Limiter) Wait(ctx context.Context, name string) error {
    return l.getLimiter(name).Wait(ctx)
}

func (l *Limiter) SetLimit(ctx context.Context, name string, limit rate.Limit) {
    l.getLimiter(name).SetLimit(limit)
    err := l.db.Model(&pds.PDS{}).Where(&pds.PDS{Host: name}).Updates(&pds.PDS{CrawlLimit: int(limit)}).Error
    if err != nil {
        zerolog.Ctx(ctx).Error().Err(err).Msgf("Failed to persist rate limit change for %q: %s", name, err)
    }
}

func (l *Limiter) SetAllLimits(ctx context.Context, limit rate.Limit) {
    l.mu.RLock()
    for name, limiter := range l.limiter {
        limiter.SetLimit(limit)
        err := l.db.Model(&pds.PDS{}).Where(&pds.PDS{Host: name}).Updates(&pds.PDS{CrawlLimit: int(limit)}).Error
        if err != nil {
            zerolog.Ctx(ctx).Error().Err(err).Msgf("Failed to persist rate limit change for %q: %s", name, err)
        }
    }
    l.mu.RUnlock()
}

cmd/record-indexer/scheduler.go (Normal file, 136 lines added)
@@ -0,0 +1,136 @@
package main

import (
    "context"
    "fmt"
    "time"

    "github.com/rs/zerolog"
    "github.com/uabluerail/indexer/pds"
    "github.com/uabluerail/indexer/repo"
    "gorm.io/gorm"
)

type Scheduler struct {
    db     *gorm.DB
    output chan<- WorkItem

    queue      map[string]*repo.Repo
    inProgress map[string]*repo.Repo
}

func NewScheduler(output chan<- WorkItem, db *gorm.DB) *Scheduler {
    return &Scheduler{
        db:         db,
        output:     output,
        queue:      map[string]*repo.Repo{},
        inProgress: map[string]*repo.Repo{},
    }
}

func (s *Scheduler) Start(ctx context.Context) error {
    go s.run(ctx)
    return nil
}

func (s *Scheduler) run(ctx context.Context) {
    log := zerolog.Ctx(ctx)
    t := time.NewTicker(time.Minute)
    defer t.Stop()

    if err := s.fillQueue(ctx); err != nil {
        log.Error().Err(err).Msgf("Failed to get more tasks for the queue: %s", err)
    }

    done := make(chan string)
    for {
        if len(s.queue) > 0 {
            next := WorkItem{signal: make(chan struct{})}
            for _, r := range s.queue {
                next.Repo = r
                break
            }

            select {
            case <-ctx.Done():
                return
            case <-t.C:
                if err := s.fillQueue(ctx); err != nil {
                    log.Error().Err(err).Msgf("Failed to get more tasks for the queue: %s", err)
                }
            case s.output <- next:
                delete(s.queue, next.Repo.DID)
                s.inProgress[next.Repo.DID] = next.Repo
                go func(did string, ch chan struct{}) {
                    select {
                    case <-ch:
                    case <-ctx.Done():
                    }
                    done <- did
                }(next.Repo.DID, next.signal)
                s.updateQueueLenMetrics()
            case did := <-done:
                delete(s.inProgress, did)
                s.updateQueueLenMetrics()
            }
        } else {
            select {
            case <-ctx.Done():
                return
            case <-t.C:
                if err := s.fillQueue(ctx); err != nil {
                    log.Error().Err(err).Msgf("Failed to get more tasks for the queue: %s", err)
                }
            case did := <-done:
                delete(s.inProgress, did)
                s.updateQueueLenMetrics()
            }
        }
    }
}

func (s *Scheduler) fillQueue(ctx context.Context) error {
    const maxQueueLen = 10000

    if len(s.queue)+len(s.inProgress) >= maxQueueLen {
        return nil
    }

    remotes := []pds.PDS{}
    if err := s.db.Find(&remotes).Error; err != nil {
        return fmt.Errorf("failed to get the list of PDSs: %w", err)
    }
    if len(remotes) == 0 {
        // No known PDSs yet, nothing to queue (and avoids dividing by zero below).
        return nil
    }

    perPDSLimit := maxQueueLen * 2 / len(remotes)

    for _, remote := range remotes {
        repos := []repo.Repo{}

        err := s.db.Model(&repos).
            Where(`pds = ? AND (
                last_indexed_rev is null OR last_indexed_rev = ''
                OR (first_rev_since_reset is not null AND first_rev_since_reset <> '' AND last_indexed_rev < first_rev_since_reset))`,
                remote.ID).
            Limit(perPDSLimit).Find(&repos).Error

        if err != nil {
            return fmt.Errorf("querying DB: %w", err)
        }
        for _, r := range repos {
            if s.queue[r.DID] != nil || s.inProgress[r.DID] != nil {
                continue
            }
            copied := r
            s.queue[r.DID] = &copied
            reposQueued.Inc()
        }
        s.updateQueueLenMetrics()
    }

    return nil
}

func (s *Scheduler) updateQueueLenMetrics() {
    queueLength.WithLabelValues("queued").Set(float64(len(s.queue)))
    queueLength.WithLabelValues("inProgress").Set(float64(len(s.inProgress)))
}

cmd/record-indexer/workerpool.go (Normal file, 238 lines added)
@@ -0,0 +1,238 @@
package main

import (
    "bytes"
    "context"
    "fmt"
    "net/url"
    "strings"
    "time"

    "github.com/imax9000/errors"
    "github.com/rs/zerolog"
    "gorm.io/gorm"
    "gorm.io/gorm/clause"

    comatproto "github.com/bluesky-social/indigo/api/atproto"
    "github.com/bluesky-social/indigo/xrpc"

    "github.com/uabluerail/bsky-tools/xrpcauth"
    "github.com/uabluerail/indexer/models"
    "github.com/uabluerail/indexer/repo"
    "github.com/uabluerail/indexer/util/resolver"
)

type WorkItem struct {
    Repo   *repo.Repo
    signal chan struct{}
}

type WorkerPool struct {
    db      *gorm.DB
    input   <-chan WorkItem
    limiter *Limiter

    workerSignals []chan struct{}
    resize        chan int
}

func NewWorkerPool(input <-chan WorkItem, db *gorm.DB, size int, limiter *Limiter) *WorkerPool {
    r := &WorkerPool{
        db:      db,
        input:   input,
        limiter: limiter,
        resize:  make(chan int),
    }
    r.workerSignals = make([]chan struct{}, size)
    for i := range r.workerSignals {
        r.workerSignals[i] = make(chan struct{})
    }
    return r
}

func (p *WorkerPool) Start(ctx context.Context) error {
    go p.run(ctx)
    return nil
}

func (p *WorkerPool) Resize(ctx context.Context, size int) error {
    select {
    case <-ctx.Done():
        return ctx.Err()
    case p.resize <- size:
        return nil
    }
}

func (p *WorkerPool) run(ctx context.Context) {
    for _, ch := range p.workerSignals {
        go p.worker(ctx, ch)
    }
    workerPoolSize.Set(float64(len(p.workerSignals)))

    for {
        select {
        case <-ctx.Done():
            for _, ch := range p.workerSignals {
                close(ch)
            }
            // also wait for all workers to stop?
            return
        case newSize := <-p.resize:
            switch {
            case newSize > len(p.workerSignals):
                ch := make([]chan struct{}, newSize-len(p.workerSignals))
                for i := range ch {
                    ch[i] = make(chan struct{})
                    go p.worker(ctx, ch[i])
                }
                p.workerSignals = append(p.workerSignals, ch...)
                workerPoolSize.Set(float64(len(p.workerSignals)))
            case newSize < len(p.workerSignals) && newSize > 0:
                for _, ch := range p.workerSignals[newSize:] {
                    close(ch)
                }
                p.workerSignals = p.workerSignals[:newSize]
                workerPoolSize.Set(float64(len(p.workerSignals)))
            }
        }
    }
}

func (p *WorkerPool) worker(ctx context.Context, signal chan struct{}) {
    log := zerolog.Ctx(ctx)
    for {
        select {
        case <-ctx.Done():
            return
        case <-signal:
            return
        case work := <-p.input:
            updates := &repo.Repo{}
            if err := p.doWork(ctx, work); err != nil {
                log.Error().Err(err).Msgf("Work task %q failed: %s", work.Repo.DID, err)
                updates.LastError = err.Error()
                reposIndexed.WithLabelValues("false").Inc()
            } else {
                reposIndexed.WithLabelValues("true").Inc()
            }
            updates.LastIndexAttempt = time.Now()
            err := p.db.Model(&repo.Repo{}).
                Where(&repo.Repo{Model: gorm.Model{ID: work.Repo.ID}}).
                Updates(updates).Error
            if err != nil {
                log.Error().Err(err).Msgf("Failed to update repo info for %q: %s", work.Repo.DID, err)
            }
        }
    }
}

func (p *WorkerPool) doWork(ctx context.Context, work WorkItem) error {
    log := zerolog.Ctx(ctx)
    defer close(work.signal)

    doc, err := resolver.GetDocument(ctx, work.Repo.DID)
    if err != nil {
        return fmt.Errorf("resolving did %q: %w", work.Repo.DID, err)
    }

    pdsHost := ""
    for _, srv := range doc.Service {
        if srv.Type != "AtprotoPersonalDataServer" {
            continue
        }
        pdsHost = srv.ServiceEndpoint
    }
    if pdsHost == "" {
        return fmt.Errorf("did not find any PDS in DID Document")
    }
    u, err := url.Parse(pdsHost)
    if err != nil {
        return fmt.Errorf("PDS endpoint (%q) is an invalid URL: %w", pdsHost, err)
    }
    if u.Host == "" {
        return fmt.Errorf("PDS endpoint (%q) doesn't have a host part", pdsHost)
    }

    client := xrpcauth.NewAnonymousClient(ctx)
    client.Host = u.String()

retry:
    if p.limiter != nil {
        if err := p.limiter.Wait(ctx, u.String()); err != nil {
            return fmt.Errorf("failed to wait on rate limiter: %w", err)
        }
    }

    b, err := comatproto.SyncGetRepo(ctx, client, work.Repo.DID, "")
    if err != nil {
        if err, ok := errors.As[*xrpc.Error](err); ok {
            if err.IsThrottled() && err.Ratelimit != nil {
                log.Debug().Str("pds", u.String()).Msgf("Hit a rate limit (%s), sleeping until %s", err.Ratelimit.Policy, err.Ratelimit.Reset)
                time.Sleep(time.Until(err.Ratelimit.Reset))
                goto retry
            }
        }

        reposFetched.WithLabelValues(u.String(), "false").Inc()
        return fmt.Errorf("failed to fetch repo: %w", err)
    }
    reposFetched.WithLabelValues(u.String(), "true").Inc()

    newRev, err := repo.GetRev(ctx, bytes.NewReader(b))
    if err != nil {
        return fmt.Errorf("failed to read 'rev' from the fetched repo: %w", err)
    }

    newRecs, err := repo.ExtractRecords(ctx, bytes.NewReader(b))
    if err != nil {
        return fmt.Errorf("failed to extract records: %w", err)
    }
    recs := []repo.Record{}
    for k, v := range newRecs {
        parts := strings.SplitN(k, "/", 2)
        if len(parts) != 2 {
            log.Warn().Msgf("Unexpected key format: %q", k)
            continue
        }
        recs = append(recs, repo.Record{
            Repo:       models.ID(work.Repo.ID),
            Collection: parts[0],
            Rkey:       parts[1],
            Content:    v,
        })
    }
    recordsFetched.Add(float64(len(recs)))
    if len(recs) > 0 {
        for _, batch := range splitInBatches(recs, 500) {
            result := p.db.Model(&repo.Record{}).
                Clauses(clause.OnConflict{DoUpdates: clause.AssignmentColumns([]string{"content"}),
                    Columns: []clause.Column{{Name: "repo"}, {Name: "collection"}, {Name: "rkey"}}}).
                Create(batch)
            if err := result.Error; err != nil {
                return fmt.Errorf("inserting records into the database: %w", err)
            }
            recordsInserted.Add(float64(result.RowsAffected))
        }
    }

    err = p.db.Model(&repo.Repo{}).Where(&repo.Repo{Model: gorm.Model{ID: work.Repo.ID}}).
        Updates(&repo.Repo{LastIndexedRev: newRev}).Error
    if err != nil {
        return fmt.Errorf("updating repo rev: %w", err)
    }

    return nil
}

func splitInBatches[T any](s []T, batchSize int) [][]T {
    var r [][]T
    for i := 0; i < len(s); i += batchSize {
        if i+batchSize < len(s) {
            r = append(r, s[i:i+batchSize])
        } else {
            r = append(r, s[i:])
        }
    }
    return r
}