Remove everything that's not needed for PLC mirror

This commit is contained in:
Max Ignatenko 2024-09-25 22:00:00 +01:00
parent 56727bbe11
commit 9b877c1524
89 changed files with 12 additions and 38231 deletions


@@ -1,14 +0,0 @@
FROM golang:1.22.3 as builder
WORKDIR /app
COPY go.mod go.sum ./
RUN go mod download
COPY . ./
RUN go build -trimpath ./cmd/consumer
FROM alpine:latest as certs
RUN apk --update add ca-certificates
FROM debian:stable-slim
COPY --from=certs /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/ca-certificates.crt
COPY --from=builder /app/consumer .
ENTRYPOINT ["./consumer"]


@@ -1,7 +0,0 @@
*
**/*
!go.mod
!go.sum
!**/*.go
cmd/**
!cmd/consumer


@@ -1,622 +0,0 @@
package main
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"math"
"net/http"
"net/url"
"path"
"strings"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/gorilla/websocket"
"github.com/prometheus/client_golang/prometheus"
"github.com/rs/zerolog"
"gorm.io/gorm"
"gorm.io/gorm/clause"
comatproto "github.com/bluesky-social/indigo/api/atproto"
"github.com/bluesky-social/indigo/xrpc"
"github.com/ipld/go-ipld-prime/codec/dagcbor"
"github.com/ipld/go-ipld-prime/datamodel"
"github.com/ipld/go-ipld-prime/node/basicnode"
"github.com/uabluerail/indexer/models"
"github.com/uabluerail/indexer/pds"
"github.com/uabluerail/indexer/repo"
"github.com/uabluerail/indexer/util/fix"
"github.com/uabluerail/indexer/util/resolver"
)
const lastRevUpdateInterval = 24 * time.Hour
type BadRecord struct {
ID models.ID `gorm:"primarykey"`
CreatedAt time.Time
PDS models.ID `gorm:"index"`
Cursor int64
Error string
Content []byte
}
type Consumer struct {
db *gorm.DB
remote pds.PDS
running chan struct{}
lastCursorPersist time.Time
}
func NewConsumer(ctx context.Context, remote *pds.PDS, db *gorm.DB) (*Consumer, error) {
if err := db.AutoMigrate(&BadRecord{}); err != nil {
return nil, fmt.Errorf("db.AutoMigrate: %s", err)
}
return &Consumer{
db: db,
remote: *remote,
running: make(chan struct{}),
}, nil
}
func (c *Consumer) Start(ctx context.Context) error {
go c.run(ctx)
return nil
}
func (c *Consumer) Wait(ctx context.Context) error {
select {
case <-ctx.Done():
return ctx.Err()
case <-c.running:
// Channel got closed
return nil
}
}
func (c *Consumer) run(ctx context.Context) {
log := zerolog.Ctx(ctx).With().Str("pds", c.remote.Host).Logger()
ctx = log.WithContext(ctx)
backoffTimer := backoff.NewExponentialBackOff(
backoff.WithMaxElapsedTime(0),
backoff.WithInitialInterval(time.Second),
backoff.WithMaxInterval(5*time.Minute),
)
pdsOnline.WithLabelValues(c.remote.Host).Set(0)
defer close(c.running)
for {
select {
case <-c.running:
log.Error().Msgf("Attempt to start previously stopped consumer")
return
case <-ctx.Done():
log.Info().Msgf("Consumer stopped")
lastEventTimestamp.DeletePartialMatch(prometheus.Labels{"remote": c.remote.Host})
eventCounter.DeletePartialMatch(prometheus.Labels{"remote": c.remote.Host})
reposDiscovered.DeletePartialMatch(prometheus.Labels{"remote": c.remote.Host})
postsByLanguageIndexed.DeletePartialMatch(prometheus.Labels{"remote": c.remote.Host})
pdsOnline.DeletePartialMatch(prometheus.Labels{"remote": c.remote.Host})
return
default:
start := time.Now()
if err := c.runOnce(ctx); err != nil {
log.Error().Err(err).Msgf("Consumer of %q failed (will be restarted): %s", c.remote.Host, err)
connectionFailures.WithLabelValues(c.remote.Host).Inc()
}
if time.Since(start) > backoffTimer.MaxInterval*3 {
// XXX: assume that c.runOnce did some useful work in this case,
// even though it might have been stuck on some absurdly long timeouts.
backoffTimer.Reset()
}
time.Sleep(backoffTimer.NextBackOff())
}
}
}
func (c *Consumer) runOnce(ctx context.Context) error {
log := zerolog.Ctx(ctx)
log.Info().
Int64("cursor", c.remote.Cursor).
Int64("first_cursor_since_reset", c.remote.FirstCursorSinceReset).
Msgf("Connecting to firehose of %s...", c.remote.Host)
addr, err := url.Parse(c.remote.Host)
if err != nil {
return fmt.Errorf("parsing URL %q: %s", c.remote.Host, err)
}
addr.Scheme = "wss"
addr.Path = path.Join(addr.Path, "xrpc/com.atproto.sync.subscribeRepos")
if c.remote.Cursor > 0 {
params := url.Values{"cursor": []string{fmt.Sprint(c.remote.Cursor)}}
addr.RawQuery = params.Encode()
}
conn, _, err := websocket.DefaultDialer.DialContext(ctx, addr.String(), http.Header{})
if err != nil {
return fmt.Errorf("establishing websocker connection: %w", err)
}
defer conn.Close()
pdsOnline.WithLabelValues(c.remote.Host).Set(1)
defer func() { pdsOnline.WithLabelValues(c.remote.Host).Set(0) }()
ch := make(chan bool)
defer close(ch)
go func() {
t := time.NewTicker(time.Minute)
defer t.Stop()
for {
select {
case <-ch:
return
case <-t.C:
if err := conn.WriteControl(websocket.PingMessage, []byte("ping"), time.Now().Add(time.Minute)); err != nil {
log.Error().Err(err).Msgf("Failed to send ping: %s", err)
}
}
}
}()
first := true
for {
select {
case <-ctx.Done():
return ctx.Err()
default:
_, b, err := conn.ReadMessage()
if err != nil {
return fmt.Errorf("websocket.ReadMessage: %w", err)
}
r := bytes.NewReader(b)
proto := basicnode.Prototype.Any
headerNode := proto.NewBuilder()
if err := (&dagcbor.DecodeOptions{DontParseBeyondEnd: true}).Decode(headerNode, r); err != nil {
return fmt.Errorf("unmarshaling message header: %w", err)
}
header, err := parseHeader(headerNode.Build())
if err != nil {
return fmt.Errorf("parsing message header: %w", err)
}
switch header.Op {
case 1:
if err := c.processMessage(ctx, header.Type, r, first); err != nil {
if ctx.Err() != nil {
// We're shutting down, so the error is most likely due to that.
return err
}
const maxBadRecords = 500
var count int64
if err2 := c.db.Model(&BadRecord{}).Where(&BadRecord{PDS: c.remote.ID}).Count(&count).Error; err2 != nil {
return err
}
if count >= maxBadRecords {
return err
}
log.Error().Err(err).Str("pds", c.remote.Host).Msgf("Failed to process message at cursor %d: %s", c.remote.Cursor, err)
err := c.db.Create(&BadRecord{
PDS: c.remote.ID,
Cursor: c.remote.Cursor,
Error: err.Error(),
Content: b,
}).Error
if err != nil {
return fmt.Errorf("failed to store bad message: %s", err)
}
}
case -1:
bodyNode := proto.NewBuilder()
if err := (&dagcbor.DecodeOptions{DontParseBeyondEnd: true, AllowLinks: true}).Decode(bodyNode, r); err != nil {
return fmt.Errorf("unmarshaling message body: %w", err)
}
body, err := parseError(bodyNode.Build())
if err != nil {
return fmt.Errorf("parsing error payload: %w", err)
}
return &body
default:
log.Warn().Msgf("Unknown 'op' value received: %d", header.Op)
}
first = false
}
}
}
func (c *Consumer) resetCursor(ctx context.Context, seq int64) error {
zerolog.Ctx(ctx).Warn().Str("pds", c.remote.Host).Msgf("Cursor reset: %d -> %d", c.remote.Cursor, seq)
err := c.db.Model(&c.remote).
Where(&pds.PDS{ID: c.remote.ID}).
Updates(&pds.PDS{FirstCursorSinceReset: seq}).Error
if err != nil {
return fmt.Errorf("updating FirstCursorSinceReset: %w", err)
}
c.remote.FirstCursorSinceReset = seq
return nil
}
func (c *Consumer) updateCursor(ctx context.Context, seq int64) error {
if math.Abs(float64(seq-c.remote.Cursor)) < 100 && time.Since(c.lastCursorPersist) < 5*time.Second {
c.remote.Cursor = seq
return nil
}
err := c.db.Model(&c.remote).
Where(&pds.PDS{ID: c.remote.ID}).
Updates(&pds.PDS{Cursor: seq}).Error
if err != nil {
return fmt.Errorf("updating Cursor: %w", err)
}
c.remote.Cursor = seq
c.lastCursorPersist = time.Now()
return nil
}
func (c *Consumer) processMessage(ctx context.Context, typ string, r io.Reader, first bool) error {
log := zerolog.Ctx(ctx)
eventCounter.WithLabelValues(c.remote.Host, typ).Inc()
switch typ {
case "#commit":
payload := &comatproto.SyncSubscribeRepos_Commit{}
if err := payload.UnmarshalCBOR(r); err != nil {
return fmt.Errorf("failed to unmarshal commit: %w", err)
}
exportEventTimestamp(ctx, c.remote.Host, payload.Time)
if c.remote.FirstCursorSinceReset == 0 {
if err := c.resetCursor(ctx, payload.Seq); err != nil {
return fmt.Errorf("handling cursor reset: %w", err)
}
}
repoInfo, created, err := repo.EnsureExists(ctx, c.db, payload.Repo)
if err != nil {
return fmt.Errorf("repo.EnsureExists(%q): %w", payload.Repo, err)
}
if repoInfo.LastKnownKey == "" {
_, pubKey, err := resolver.GetPDSEndpointAndPublicKey(ctx, payload.Repo)
if err != nil {
return fmt.Errorf("failed to get DID doc for %q: %w", payload.Repo, err)
}
repoInfo.LastKnownKey = pubKey
err = c.db.Model(repoInfo).Where(&repo.Repo{ID: repoInfo.ID}).Updates(&repo.Repo{LastKnownKey: pubKey}).Error
if err != nil {
return fmt.Errorf("failed to update the key for %q: %w", payload.Repo, err)
}
}
if repoInfo.PDS != c.remote.ID {
u, _, err := resolver.GetPDSEndpointAndPublicKey(ctx, payload.Repo)
if err == nil {
cur, err := pds.EnsureExists(ctx, c.db, u.String())
if err == nil {
if repoInfo.PDS != cur.ID {
// Repo was migrated, lets update our record.
err := c.db.Model(repoInfo).Where(&repo.Repo{ID: repoInfo.ID}).Updates(&repo.Repo{PDS: cur.ID}).Error
if err != nil {
log.Error().Err(err).Msgf("Repo %q was migrated to %q, but updating the repo has failed: %s", payload.Repo, cur.Host, err)
}
}
repoInfo.PDS = cur.ID
} else {
log.Error().Err(err).Msgf("Failed to get PDS record for %q: %s", u, err)
}
} else {
log.Error().Err(err).Msgf("Failed to get PDS endpoint for repo %q: %s", payload.Repo, err)
}
if repoInfo.PDS != c.remote.ID {
// We checked a recent version of DID doc and this is still not a correct PDS.
log.Error().Str("did", payload.Repo).Str("rev", payload.Rev).
Msgf("Commit from an incorrect PDS, skipping")
return nil
}
}
if created {
reposDiscovered.WithLabelValues(c.remote.Host).Inc()
}
expectRecords := false
deletions := []string{}
for _, op := range payload.Ops {
switch op.Action {
case "create":
expectRecords = true
case "update":
expectRecords = true
case "delete":
deletions = append(deletions, op.Path)
}
}
for _, d := range deletions {
parts := strings.SplitN(d, "/", 2)
if len(parts) != 2 {
continue
}
err := c.db.Model(&repo.Record{}).
Where(&repo.Record{
Repo: models.ID(repoInfo.ID),
Collection: parts[0],
Rkey: parts[1]}).
Updates(&repo.Record{Deleted: true}).Error
if err != nil {
return fmt.Errorf("failed to mark %s/%s as deleted: %w", payload.Repo, d, err)
}
}
newRecs, err := repo.ExtractRecords(ctx, bytes.NewReader(payload.Blocks), repoInfo.LastKnownKey)
if errors.Is(err, repo.ErrInvalidSignature) {
// Key might have been updated recently.
_, pubKey, err2 := resolver.GetPDSEndpointAndPublicKey(ctx, payload.Repo)
if err2 != nil {
return fmt.Errorf("failed to get DID doc for %q: %w", payload.Repo, err2)
}
if repoInfo.LastKnownKey != pubKey {
repoInfo.LastKnownKey = pubKey
err2 = c.db.Model(repoInfo).Where(&repo.Repo{ID: repoInfo.ID}).Updates(&repo.Repo{LastKnownKey: pubKey}).Error
if err2 != nil {
return fmt.Errorf("failed to update the key for %q: %w", payload.Repo, err2)
}
// Retry with the new key.
newRecs, err = repo.ExtractRecords(ctx, bytes.NewReader(payload.Blocks), pubKey)
}
}
if err != nil {
return fmt.Errorf("failed to extract records: %w", err)
}
recs := []repo.Record{}
for k, v := range newRecs {
parts := strings.SplitN(k, "/", 2)
if len(parts) != 2 {
log.Warn().Msgf("Unexpected key format: %q", k)
continue
}
langs, _, err := repo.GetLang(ctx, v)
if err == nil {
for _, lang := range langs {
postsByLanguageIndexed.WithLabelValues(c.remote.Host, lang).Inc()
}
}
recs = append(recs, repo.Record{
Repo: models.ID(repoInfo.ID),
Collection: parts[0],
Rkey: parts[1],
// XXX: proper replacement of \u0000 would require full parsing of JSON
// and recursive iteration over all string values, but this
// should work well enough for now.
Content: fix.EscapeNullCharForPostgres(v),
AtRev: payload.Rev,
})
}
if len(recs) == 0 && expectRecords {
log.Debug().Int64("seq", payload.Seq).Str("pds", c.remote.Host).Msgf("len(recs) == 0")
}
if len(recs) > 0 {
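// Upsert keyed on (repo, collection, rkey): overwrite content/at_rev only when the
// stored row has no at_rev yet, or its at_rev is older than the incoming one.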
err = c.db.Model(&repo.Record{}).
Clauses(clause.OnConflict{
Where: clause.Where{Exprs: []clause.Expression{clause.Or(
clause.Eq{Column: clause.Column{Name: "at_rev", Table: "records"}, Value: nil},
clause.Eq{Column: clause.Column{Name: "at_rev", Table: "records"}, Value: ""},
clause.Lt{
Column: clause.Column{Name: "at_rev", Table: "records"},
Value: clause.Column{Name: "at_rev", Table: "excluded"}},
)}},
DoUpdates: clause.AssignmentColumns([]string{"content", "at_rev"}),
Columns: []clause.Column{{Name: "repo"}, {Name: "collection"}, {Name: "rkey"}}}).
Create(recs).Error
if err != nil {
return fmt.Errorf("inserting records into the database: %w", err)
}
}
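// At most once per lastRevUpdateInterval, note the rev seen on the firehose for
// repos that are already fully indexed past the latest cursor reset.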
if repoInfo.FirstCursorSinceReset > 0 && repoInfo.FirstRevSinceReset != "" &&
repoInfo.LastIndexedRev != "" &&
c.remote.FirstCursorSinceReset > 0 &&
repoInfo.FirstCursorSinceReset >= c.remote.FirstCursorSinceReset &&
repoInfo.FirstRevSinceReset <= repoInfo.LastIndexedRev &&
time.Since(repoInfo.UpdatedAt) > lastRevUpdateInterval {
err = c.db.Model(&repo.Repo{}).Where(&repo.Repo{ID: repoInfo.ID}).
Updates(&repo.Repo{
LastFirehoseRev: payload.Rev,
}).Error
if err != nil {
log.Error().Err(err).Msgf("Failed to update last_firehose_rev for %q: %s", repoInfo.DID, err)
}
}
if payload.TooBig {
// Just trigger a re-index by resetting rev.
err := c.db.Model(&repo.Repo{}).Where(&repo.Repo{ID: repoInfo.ID}).
Updates(&repo.Repo{
FirstCursorSinceReset: c.remote.FirstCursorSinceReset,
FirstRevSinceReset: payload.Rev,
}).Error
if err != nil {
return fmt.Errorf("failed to update repo info after cursor reset: %w", err)
}
}
if repoInfo.FirstCursorSinceReset != c.remote.FirstCursorSinceReset {
err := c.db.Model(&repo.Repo{}).Debug().Where(&repo.Repo{ID: repoInfo.ID}).
Updates(&repo.Repo{
FirstCursorSinceReset: c.remote.FirstCursorSinceReset,
FirstRevSinceReset: payload.Rev,
}).Error
if err != nil {
return fmt.Errorf("failed to update repo info after cursor reset: %w", err)
}
}
if err := c.updateCursor(ctx, payload.Seq); err != nil {
return err
}
case "#handle":
payload := &comatproto.SyncSubscribeRepos_Handle{}
if err := payload.UnmarshalCBOR(r); err != nil {
return fmt.Errorf("failed to unmarshal commit: %w", err)
}
exportEventTimestamp(ctx, c.remote.Host, payload.Time)
if c.remote.FirstCursorSinceReset == 0 {
if err := c.resetCursor(ctx, payload.Seq); err != nil {
return fmt.Errorf("handling cursor reset: %w", err)
}
}
// No-op, we don't store handles.
if err := c.updateCursor(ctx, payload.Seq); err != nil {
return err
}
case "#migrate":
payload := &comatproto.SyncSubscribeRepos_Migrate{}
if err := payload.UnmarshalCBOR(r); err != nil {
return fmt.Errorf("failed to unmarshal commit: %w", err)
}
exportEventTimestamp(ctx, c.remote.Host, payload.Time)
if c.remote.FirstCursorSinceReset == 0 {
if err := c.resetCursor(ctx, payload.Seq); err != nil {
return fmt.Errorf("handling cursor reset: %w", err)
}
}
log.Debug().Interface("payload", payload).Str("did", payload.Did).Msgf("MIGRATION")
// TODO
if err := c.updateCursor(ctx, payload.Seq); err != nil {
return err
}
case "#tombstone":
payload := &comatproto.SyncSubscribeRepos_Tombstone{}
if err := payload.UnmarshalCBOR(r); err != nil {
return fmt.Errorf("failed to unmarshal commit: %w", err)
}
exportEventTimestamp(ctx, c.remote.Host, payload.Time)
if c.remote.FirstCursorSinceReset == 0 {
if err := c.resetCursor(ctx, payload.Seq); err != nil {
return fmt.Errorf("handling cursor reset: %w", err)
}
}
// TODO
if err := c.updateCursor(ctx, payload.Seq); err != nil {
return err
}
case "#info":
payload := &comatproto.SyncSubscribeRepos_Info{}
if err := payload.UnmarshalCBOR(r); err != nil {
return fmt.Errorf("failed to unmarshal commit: %w", err)
}
switch payload.Name {
case "OutdatedCursor":
if !first {
log.Warn().Msgf("Received cursor reset notification in the middle of a stream: %+v", payload)
}
c.remote.FirstCursorSinceReset = 0
default:
log.Error().Msgf("Unknown #info message %q: %+v", payload.Name, payload)
}
case "#identity":
payload := &comatproto.SyncSubscribeRepos_Identity{}
if err := payload.UnmarshalCBOR(r); err != nil {
return fmt.Errorf("failed to unmarshal commit: %w", err)
}
exportEventTimestamp(ctx, c.remote.Host, payload.Time)
log.Trace().Str("did", payload.Did).Str("type", typ).Int64("seq", payload.Seq).
Msgf("#identity message: %s seq=%d time=%q", payload.Did, payload.Seq, payload.Time)
resolver.Resolver.FlushCacheFor(payload.Did)
// TODO: fetch DID doc and update PDS field?
default:
b, err := io.ReadAll(r)
if err != nil {
log.Error().Err(err).Msgf("Failed to read message payload: %s", err)
}
log.Warn().Msgf("Unknown message type received: %s payload=%q", typ, string(b))
}
return nil
}
type Header struct {
Op int64
Type string
}
func parseHeader(node datamodel.Node) (Header, error) {
r := Header{}
op, err := node.LookupByString("op")
if err != nil {
return r, fmt.Errorf("missing 'op': %w", err)
}
r.Op, err = op.AsInt()
if err != nil {
return r, fmt.Errorf("op.AsInt(): %w", err)
}
if r.Op == -1 {
// Error frame, type should not be present
return r, nil
}
t, err := node.LookupByString("t")
if err != nil {
return r, fmt.Errorf("missing 't': %w", err)
}
r.Type, err = t.AsString()
if err != nil {
return r, fmt.Errorf("t.AsString(): %w", err)
}
return r, nil
}
func parseError(node datamodel.Node) (xrpc.XRPCError, error) {
r := xrpc.XRPCError{}
e, err := node.LookupByString("error")
if err != nil {
return r, fmt.Errorf("missing 'error': %w", err)
}
r.ErrStr, err = e.AsString()
if err != nil {
return r, fmt.Errorf("error.AsString(): %w", err)
}
m, err := node.LookupByString("message")
if err == nil {
r.Message, err = m.AsString()
if err != nil {
return r, fmt.Errorf("message.AsString(): %w", err)
}
} else if !errors.Is(err, datamodel.ErrNotExists{}) {
return r, fmt.Errorf("looking up 'message': %w", err)
}
return r, nil
}
func exportEventTimestamp(ctx context.Context, remote string, timestamp string) {
if t, err := time.Parse(time.RFC3339, timestamp); err != nil {
zerolog.Ctx(ctx).Error().Err(err).Str("pds", remote).Msgf("Failed to parse %q as a timestamp: %s", timestamp, err)
} else {
lastEventTimestamp.WithLabelValues(remote).Set(float64(t.Unix()))
}
}


@@ -1,268 +0,0 @@
package main
import (
"context"
"flag"
"fmt"
"io"
"log"
"net/http"
_ "net/http/pprof"
"os"
"os/signal"
"path/filepath"
"runtime"
"runtime/debug"
"strings"
"sync"
"syscall"
"time"
_ "github.com/joho/godotenv/autoload"
"github.com/kelseyhightower/envconfig"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/rs/zerolog"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"github.com/uabluerail/indexer/pds"
"github.com/uabluerail/indexer/util/gormzerolog"
)
type Config struct {
LogFile string
LogFormat string `default:"text"`
LogLevel int64 `default:"1"`
MetricsPort string `split_words:"true"`
DBUrl string `envconfig:"POSTGRES_URL"`
Relays string
}
var config Config
func runMain(ctx context.Context) error {
ctx = setupLogging(ctx)
log := zerolog.Ctx(ctx)
log.Debug().Msgf("Starting up...")
db, err := gorm.Open(postgres.Open(config.DBUrl), &gorm.Config{
Logger: gormzerolog.New(&logger.Config{
SlowThreshold: 3 * time.Second,
IgnoreRecordNotFoundError: true,
}, nil),
})
if err != nil {
return fmt.Errorf("connecting to the database: %w", err)
}
log.Debug().Msgf("DB connection established")
if config.Relays != "" {
for _, host := range strings.Split(config.Relays, ",") {
c, err := NewRelayConsumer(ctx, host, db)
if err != nil {
log.Error().Err(err).Msgf("Failed to create relay consumer for %q: %s", host, err)
}
c.Start(ctx)
}
}
consumersCh := make(chan struct{})
go runConsumers(ctx, db, consumersCh)
log.Info().Msgf("Starting HTTP listener on %q...", config.MetricsPort)
http.Handle("/metrics", promhttp.Handler())
srv := &http.Server{Addr: fmt.Sprintf(":%s", config.MetricsPort)}
errCh := make(chan error)
go func() {
errCh <- srv.ListenAndServe()
}()
select {
case <-ctx.Done():
if err := srv.Shutdown(context.Background()); err != nil {
return fmt.Errorf("HTTP server shutdown failed: %w", err)
}
}
log.Info().Msgf("Waiting for consumers to stop...")
<-consumersCh
return <-errCh
}
func runConsumers(ctx context.Context, db *gorm.DB, doneCh chan struct{}) {
log := zerolog.Ctx(ctx)
defer close(doneCh)
type consumerHandle struct {
cancel context.CancelFunc
consumer *Consumer
}
running := map[string]consumerHandle{}
ticker := time.NewTicker(time.Minute)
defer ticker.Stop()
t := make(chan time.Time, 1)
t <- time.Now()
for {
select {
case <-t:
remotes := []pds.PDS{}
if err := db.Find(&remotes).Error; err != nil {
log.Error().Err(err).Msgf("Failed to get a list of known PDSs: %s", err)
break
}
shouldBeRunning := map[string]pds.PDS{}
for _, remote := range remotes {
if remote.Disabled {
continue
}
shouldBeRunning[remote.Host] = remote
}
for host, handle := range running {
if _, found := shouldBeRunning[host]; found {
continue
}
handle.cancel()
_ = handle.consumer.Wait(ctx)
delete(running, host)
}
for host, remote := range shouldBeRunning {
if _, found := running[host]; found {
continue
}
subCtx, cancel := context.WithCancel(ctx)
c, err := NewConsumer(subCtx, &remote, db)
if err != nil {
log.Error().Err(err).Msgf("Failed to create a consumer for %q: %s", remote.Host, err)
cancel()
continue
}
if err := c.Start(subCtx); err != nil {
log.Error().Err(err).Msgf("Failed ot start a consumer for %q: %s", remote.Host, err)
cancel()
continue
}
running[host] = consumerHandle{
cancel: cancel,
consumer: c,
}
}
case <-ctx.Done():
var wg sync.WaitGroup
for host, handle := range running {
wg.Add(1)
go func(handle consumerHandle) {
handle.cancel()
_ = handle.consumer.Wait(ctx)
wg.Done()
}(handle)
delete(running, host)
}
wg.Wait()
case v := <-ticker.C:
// Non-blocking send.
select {
case t <- v:
default:
}
}
}
}
func main() {
flag.StringVar(&config.LogFile, "log", "", "Path to the log file. If empty, will log to stderr")
flag.StringVar(&config.LogFormat, "log-format", "text", "Logging format. 'text' or 'json'")
flag.Int64Var(&config.LogLevel, "log-level", 1, "Log level. -1 - trace, 0 - debug, 1 - info, 5 - panic")
flag.StringVar(&config.Relays, "relays", "", "List of relays to connect to (for discovering new PDSs)")
if err := envconfig.Process("consumer", &config); err != nil {
log.Fatalf("envconfig.Process: %s", err)
}
flag.Parse()
ctx, _ := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
if err := runMain(ctx); err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
}
func setupLogging(ctx context.Context) context.Context {
logFile := os.Stderr
if config.LogFile != "" {
f, err := os.OpenFile(config.LogFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
log.Fatalf("Failed to open the specified log file %q: %s", config.LogFile, err)
}
logFile = f
}
var output io.Writer
switch config.LogFormat {
case "json":
output = logFile
case "text":
prefixList := []string{}
info, ok := debug.ReadBuildInfo()
if ok {
prefixList = append(prefixList, info.Path+"/")
}
basedir := ""
_, sourceFile, _, ok := runtime.Caller(0)
if ok {
basedir = filepath.Dir(sourceFile)
}
if basedir != "" && strings.HasPrefix(basedir, "/") {
prefixList = append(prefixList, basedir+"/")
head, _ := filepath.Split(basedir)
for head != "/" {
prefixList = append(prefixList, head)
head, _ = filepath.Split(strings.TrimSuffix(head, "/"))
}
}
output = zerolog.ConsoleWriter{
Out: logFile,
NoColor: true,
TimeFormat: time.RFC3339,
PartsOrder: []string{
zerolog.LevelFieldName,
zerolog.TimestampFieldName,
zerolog.CallerFieldName,
zerolog.MessageFieldName,
},
FormatFieldName: func(i interface{}) string { return fmt.Sprintf("%s:", i) },
FormatFieldValue: func(i interface{}) string { return fmt.Sprintf("%s", i) },
FormatCaller: func(i interface{}) string {
s := i.(string)
for _, p := range prefixList {
s = strings.TrimPrefix(s, p)
}
return s
},
}
default:
log.Fatalf("Invalid log format specified: %q", config.LogFormat)
}
logger := zerolog.New(output).Level(zerolog.Level(config.LogLevel)).With().Caller().Timestamp().Logger()
ctx = logger.WithContext(ctx)
zerolog.DefaultContextLogger = &logger
log.SetOutput(logger)
return ctx
}


@@ -1,36 +0,0 @@
package main
import (
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
)
var lastEventTimestamp = promauto.NewGaugeVec(prometheus.GaugeOpts{
Name: "repo_commit_received_timestamp",
Help: "Timestamp of the last event received from firehose.",
}, []string{"remote"})
var eventCounter = promauto.NewCounterVec(prometheus.CounterOpts{
Name: "repo_commits_received_counter",
Help: "Counter of events received from each remote.",
}, []string{"remote", "type"})
var reposDiscovered = promauto.NewCounterVec(prometheus.CounterOpts{
Name: "repo_discovered_counter",
Help: "Counter of newly discovered repos",
}, []string{"remote"})
var postsByLanguageIndexed = promauto.NewCounterVec(prometheus.CounterOpts{
Name: "indexer_posts_by_language_count",
Help: "Number of posts by language",
}, []string{"remote", "lang"})
var connectionFailures = promauto.NewCounterVec(prometheus.CounterOpts{
Name: "consumer_connection_failures",
Help: "Counter of firehose connection failures",
}, []string{"remote"})
var pdsOnline = promauto.NewGaugeVec(prometheus.GaugeOpts{
Name: "consumer_connection_up",
Help: "Status of a connection. 1 - up and running.",
}, []string{"remote"})


@@ -1,179 +0,0 @@
package main
import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"net/url"
"path"
"time"
comatproto "github.com/bluesky-social/indigo/api/atproto"
"github.com/gorilla/websocket"
"github.com/ipld/go-ipld-prime/codec/dagcbor"
"github.com/ipld/go-ipld-prime/node/basicnode"
"github.com/rs/zerolog"
"github.com/uabluerail/indexer/pds"
"github.com/uabluerail/indexer/util/resolver"
"gorm.io/gorm"
)
type RelayConsumer struct {
url string
db *gorm.DB
}
func NewRelayConsumer(ctx context.Context, host string, db *gorm.DB) (*RelayConsumer, error) {
addr, err := url.Parse(host)
if err != nil {
return nil, fmt.Errorf("parsing URL %q: %s", host, err)
}
addr.Scheme = "wss"
addr.Path = path.Join(addr.Path, "xrpc/com.atproto.sync.subscribeRepos")
return &RelayConsumer{db: db, url: addr.String()}, nil
}
func (c *RelayConsumer) Start(ctx context.Context) {
go c.run(ctx)
}
func (c *RelayConsumer) run(ctx context.Context) {
log := zerolog.Ctx(ctx).With().Str("relay", c.url).Logger()
ctx = log.WithContext(ctx)
for {
select {
case <-ctx.Done():
log.Info().Msgf("Relay consumer stopped")
return
default:
if err := c.runOnce(ctx); err != nil {
log.Error().Err(err).Msgf("Consumer of relay %q failed (will be restarted): %s", c.url, err)
}
time.Sleep(time.Second)
}
}
}
func (c *RelayConsumer) runOnce(ctx context.Context) error {
log := zerolog.Ctx(ctx)
conn, _, err := websocket.DefaultDialer.DialContext(ctx, c.url, http.Header{})
if err != nil {
return fmt.Errorf("establishing websocker connection: %w", err)
}
defer conn.Close()
for {
select {
case <-ctx.Done():
return ctx.Err()
default:
_, b, err := conn.ReadMessage()
if err != nil {
return fmt.Errorf("websocket.ReadMessage: %w", err)
}
r := bytes.NewReader(b)
proto := basicnode.Prototype.Any
headerNode := proto.NewBuilder()
if err := (&dagcbor.DecodeOptions{DontParseBeyondEnd: true}).Decode(headerNode, r); err != nil {
return fmt.Errorf("unmarshaling message header: %w", err)
}
header, err := parseHeader(headerNode.Build())
if err != nil {
return fmt.Errorf("parsing message header: %w", err)
}
switch header.Op {
case 1:
if err := c.processMessage(ctx, header.Type, r); err != nil {
log.Info().Err(err).Msgf("Relay consumer failed to process a message: %s", err)
}
case -1:
bodyNode := proto.NewBuilder()
if err := (&dagcbor.DecodeOptions{DontParseBeyondEnd: true, AllowLinks: true}).Decode(bodyNode, r); err != nil {
return fmt.Errorf("unmarshaling message body: %w", err)
}
body, err := parseError(bodyNode.Build())
if err != nil {
return fmt.Errorf("parsing error payload: %w", err)
}
return &body
default:
log.Warn().Msgf("Unknown 'op' value received: %d", header.Op)
}
}
}
}
func (c *RelayConsumer) processMessage(ctx context.Context, typ string, r io.Reader) error {
log := zerolog.Ctx(ctx)
did := ""
switch typ {
case "#commit":
payload := &comatproto.SyncSubscribeRepos_Commit{}
if err := payload.UnmarshalCBOR(r); err != nil {
return fmt.Errorf("failed to unmarshal commit: %w", err)
}
did = payload.Repo
case "#handle":
payload := &comatproto.SyncSubscribeRepos_Handle{}
if err := payload.UnmarshalCBOR(r); err != nil {
return fmt.Errorf("failed to unmarshal commit: %w", err)
}
did = payload.Did
case "#migrate":
payload := &comatproto.SyncSubscribeRepos_Migrate{}
if err := payload.UnmarshalCBOR(r); err != nil {
return fmt.Errorf("failed to unmarshal commit: %w", err)
}
did = payload.Did
case "#tombstone":
payload := &comatproto.SyncSubscribeRepos_Tombstone{}
if err := payload.UnmarshalCBOR(r); err != nil {
return fmt.Errorf("failed to unmarshal commit: %w", err)
}
did = payload.Did
case "#info":
// Ignore
case "#identity":
payload := &comatproto.SyncSubscribeRepos_Identity{}
if err := payload.UnmarshalCBOR(r); err != nil {
return fmt.Errorf("failed to unmarshal commit: %w", err)
}
did = payload.Did
default:
b, err := io.ReadAll(r)
if err != nil {
log.Error().Err(err).Msgf("Failed to read message payload: %s", err)
}
log.Warn().Msgf("Unknown message type received: %s payload=%q", typ, string(b))
}
if did == "" {
return nil
}
u, _, err := resolver.GetPDSEndpointAndPublicKey(ctx, did)
if err != nil {
return err
}
_, err = pds.EnsureExists(ctx, c.db, u.String())
return err
}


@@ -1,14 +0,0 @@
FROM golang:1.22.3 as builder
WORKDIR /app
COPY go.mod go.sum ./
RUN go mod download
COPY . ./
RUN go build -trimpath ./cmd/lister
FROM alpine:latest as certs
RUN apk --update add ca-certificates
FROM debian:stable-slim
COPY --from=certs /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/ca-certificates.crt
COPY --from=builder /app/lister .
ENTRYPOINT ["./lister"]


@@ -1,7 +0,0 @@
*
**/*
!go.mod
!go.sum
!**/*.go
cmd/**
!cmd/lister


@@ -1,151 +0,0 @@
package main
import (
"context"
"errors"
"time"
"github.com/rs/zerolog"
"gorm.io/gorm"
comatproto "github.com/bluesky-social/indigo/api/atproto"
"github.com/bluesky-social/indigo/did"
"github.com/uabluerail/bsky-tools/pagination"
"github.com/uabluerail/bsky-tools/xrpcauth"
"github.com/uabluerail/indexer/pds"
"github.com/uabluerail/indexer/repo"
"github.com/uabluerail/indexer/util/resolver"
)
type Lister struct {
db *gorm.DB
resolver did.Resolver
pollInterval time.Duration
listRefreshInterval time.Duration
}
func NewLister(ctx context.Context, db *gorm.DB) (*Lister, error) {
return &Lister{
db: db,
resolver: resolver.Resolver,
pollInterval: 5 * time.Minute,
listRefreshInterval: 24 * time.Hour,
}, nil
}
func (l *Lister) Start(ctx context.Context) error {
go l.run(ctx)
return nil
}
func (l *Lister) run(ctx context.Context) {
log := zerolog.Ctx(ctx)
ticker := time.NewTicker(l.pollInterval)
log.Info().Msgf("Lister starting...")
t := make(chan time.Time, 1)
t <- time.Now()
for {
select {
case <-ctx.Done():
log.Info().Msgf("Lister stopped (context expired)")
return
case <-t:
db := l.db.WithContext(ctx)
remote := pds.PDS{}
if err := db.Model(&remote).
Where("(disabled=false or disabled is null) and (last_list is null or last_list < ?)", time.Now().Add(-l.listRefreshInterval)).
Take(&remote).Error; err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) {
log.Error().Err(err).Msgf("Failed to query DB for a PDS to list repos from: %s", err)
}
break
}
if !pds.IsWhitelisted(remote.Host) {
log.Info().Msgf("PDS %q is not whitelisted, disabling it", remote.Host)
if err := db.Model(&remote).Where(&pds.PDS{ID: remote.ID}).Updates(&pds.PDS{Disabled: true}).Error; err != nil {
log.Error().Err(err).Msgf("Failed to disable PDS %q: %s", remote.Host, err)
}
break
}
client := xrpcauth.NewAnonymousClient(ctx)
client.Host = remote.Host
log.Info().Msgf("Listing repos from %q...", remote.Host)
repos, err := pagination.Reduce(
func(cursor string) (resp *comatproto.SyncListRepos_Output, nextCursor string, err error) {
resp, err = comatproto.SyncListRepos(ctx, client, cursor, 200)
if err == nil && resp.Cursor != nil {
nextCursor = *resp.Cursor
}
return
},
func(resp *comatproto.SyncListRepos_Output, acc []*comatproto.SyncListRepos_Repo) ([]*comatproto.SyncListRepos_Repo, error) {
for _, repo := range resp.Repos {
if repo == nil {
continue
}
acc = append(acc, repo)
}
return acc, nil
})
if err != nil {
log.Error().Err(err).Msgf("Failed to list repos from %q: %s", remote.Host, err)
// Update the timestamp so we don't get stuck on a single broken PDS
if err := db.Model(&remote).Updates(&pds.PDS{LastList: time.Now()}).Error; err != nil {
log.Error().Err(err).Msgf("Failed to update the timestamp of last list for %q: %s", remote.Host, err)
}
break
}
log.Info().Msgf("Received %d DIDs from %q", len(repos), remote.Host)
reposListed.WithLabelValues(remote.Host).Add(float64(len(repos)))
for _, repoInfo := range repos {
record, created, err := repo.EnsureExists(ctx, l.db, repoInfo.Did)
if err != nil {
log.Error().Err(err).Msgf("Failed to ensure that we have a record for the repo %q: %s", repoInfo.Did, err)
} else if created {
reposDiscovered.WithLabelValues(remote.Host).Inc()
}
if err == nil && record.FirstRevSinceReset == "" {
// Populate this field in case it's empty, so we don't have to wait for the first firehose event
// to trigger a resync.
err := l.db.Transaction(func(tx *gorm.DB) error {
var currentRecord repo.Repo
if err := tx.Model(&record).Where(&repo.Repo{ID: record.ID}).Take(&currentRecord).Error; err != nil {
return err
}
if currentRecord.FirstRevSinceReset != "" {
// Someone else already updated it, nothing to do.
return nil
}
var remote pds.PDS
if err := tx.Model(&remote).Where(&pds.PDS{ID: record.PDS}).Take(&remote).Error; err != nil {
return err
}
return tx.Model(&record).Where(&repo.Repo{ID: record.ID}).Updates(&repo.Repo{
FirstRevSinceReset: repoInfo.Rev,
FirstCursorSinceReset: remote.FirstCursorSinceReset,
}).Error
})
if err != nil {
log.Error().Err(err).Msgf("Failed to set the initial FirstRevSinceReset value for %q: %s", repoInfo.Did, err)
}
}
}
if err := db.Model(&remote).Updates(&pds.PDS{LastList: time.Now()}).Error; err != nil {
log.Error().Err(err).Msgf("Failed to update the timestamp of last list for %q: %s", remote.Host, err)
}
case v := <-ticker.C:
t <- v
}
}
}


@@ -1,168 +0,0 @@
package main
import (
"context"
"flag"
"fmt"
"io"
"log"
"net/http"
_ "net/http/pprof"
"os"
"os/signal"
"path/filepath"
"runtime"
"runtime/debug"
"strings"
"syscall"
"time"
_ "github.com/joho/godotenv/autoload"
"github.com/kelseyhightower/envconfig"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/rs/zerolog"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"github.com/uabluerail/indexer/util/gormzerolog"
)
type Config struct {
LogFile string
LogFormat string `default:"text"`
LogLevel int64 `default:"1"`
MetricsPort string `split_words:"true"`
DBUrl string `envconfig:"POSTGRES_URL"`
}
var config Config
func runMain(ctx context.Context) error {
ctx = setupLogging(ctx)
log := zerolog.Ctx(ctx)
log.Debug().Msgf("Starting up...")
db, err := gorm.Open(postgres.Open(config.DBUrl), &gorm.Config{
Logger: gormzerolog.New(&logger.Config{
SlowThreshold: 1 * time.Second,
IgnoreRecordNotFoundError: true,
}, nil),
})
if err != nil {
return fmt.Errorf("connecting to the database: %w", err)
}
log.Debug().Msgf("DB connection established")
lister, err := NewLister(ctx, db)
if err != nil {
return fmt.Errorf("failed to create lister: %w", err)
}
if err := lister.Start(ctx); err != nil {
return fmt.Errorf("failed to start lister: %w", err)
}
log.Info().Msgf("Starting HTTP listener on %q...", config.MetricsPort)
http.Handle("/metrics", promhttp.Handler())
srv := &http.Server{Addr: fmt.Sprintf(":%s", config.MetricsPort)}
errCh := make(chan error)
go func() {
errCh <- srv.ListenAndServe()
}()
select {
case <-ctx.Done():
if err := srv.Shutdown(context.Background()); err != nil {
return fmt.Errorf("HTTP server shutdown failed: %w", err)
}
}
return <-errCh
}
func main() {
flag.StringVar(&config.LogFile, "log", "", "Path to the log file. If empty, will log to stderr")
flag.StringVar(&config.LogFormat, "log-format", "text", "Logging format. 'text' or 'json'")
flag.Int64Var(&config.LogLevel, "log-level", 1, "Log level. -1 - trace, 0 - debug, 1 - info, 5 - panic")
if err := envconfig.Process("lister", &config); err != nil {
log.Fatalf("envconfig.Process: %s", err)
}
flag.Parse()
ctx, _ := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
if err := runMain(ctx); err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
}
func setupLogging(ctx context.Context) context.Context {
logFile := os.Stderr
if config.LogFile != "" {
f, err := os.OpenFile(config.LogFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
log.Fatalf("Failed to open the specified log file %q: %s", config.LogFile, err)
}
logFile = f
}
var output io.Writer
switch config.LogFormat {
case "json":
output = logFile
case "text":
prefixList := []string{}
info, ok := debug.ReadBuildInfo()
if ok {
prefixList = append(prefixList, info.Path+"/")
}
basedir := ""
_, sourceFile, _, ok := runtime.Caller(0)
if ok {
basedir = filepath.Dir(sourceFile)
}
if basedir != "" && strings.HasPrefix(basedir, "/") {
prefixList = append(prefixList, basedir+"/")
head, _ := filepath.Split(basedir)
for head != "/" {
prefixList = append(prefixList, head)
head, _ = filepath.Split(strings.TrimSuffix(head, "/"))
}
}
output = zerolog.ConsoleWriter{
Out: logFile,
NoColor: true,
TimeFormat: time.RFC3339,
PartsOrder: []string{
zerolog.LevelFieldName,
zerolog.TimestampFieldName,
zerolog.CallerFieldName,
zerolog.MessageFieldName,
},
FormatFieldName: func(i interface{}) string { return fmt.Sprintf("%s:", i) },
FormatFieldValue: func(i interface{}) string { return fmt.Sprintf("%s", i) },
FormatCaller: func(i interface{}) string {
s := i.(string)
for _, p := range prefixList {
s = strings.TrimPrefix(s, p)
}
return s
},
}
default:
log.Fatalf("Invalid log format specified: %q", config.LogFormat)
}
logger := zerolog.New(output).Level(zerolog.Level(config.LogLevel)).With().Caller().Timestamp().Logger()
ctx = logger.WithContext(ctx)
zerolog.DefaultContextLogger = &logger
log.SetOutput(logger)
return ctx
}


@@ -1,16 +0,0 @@
package main
import (
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
)
var reposDiscovered = promauto.NewCounterVec(prometheus.CounterOpts{
Name: "repo_discovered_counter",
Help: "Counter of newly discovered repos",
}, []string{"remote"})
var reposListed = promauto.NewCounterVec(prometheus.CounterOpts{
Name: "repo_listed_counter",
Help: "Counter of repos received by listing PDSs.",
}, []string{"remote"})


@@ -1,14 +0,0 @@
FROM golang:1.22.3 as builder
WORKDIR /app
COPY go.mod go.sum ./
RUN go mod download
COPY . ./
RUN go build -trimpath ./cmd/record-indexer
FROM alpine:latest as certs
RUN apk --update add ca-certificates
FROM debian:stable-slim
COPY --from=certs /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/ca-certificates.crt
COPY --from=builder /app/record-indexer .
ENTRYPOINT ["./record-indexer"]


@@ -1,7 +0,0 @@
*
**/*
!go.mod
!go.sum
!**/*.go
cmd/**
!cmd/record-indexer


@@ -1,77 +0,0 @@
package main
import (
"context"
"fmt"
"net/http"
"strconv"
"golang.org/x/time/rate"
)
func AddAdminHandlers(limiter *Limiter, pool *WorkerPool) {
http.HandleFunc("/rate/set", handleRateSet(limiter))
http.HandleFunc("/rate/setAll", handleRateSetAll(limiter))
http.HandleFunc("/pool/resize", handlePoolResize(pool))
}
func handlePoolResize(pool *WorkerPool) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
s := r.FormValue("size")
if s == "" {
http.Error(w, "need size", http.StatusBadRequest)
return
}
size, err := strconv.Atoi(s)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
pool.Resize(context.Background(), size)
fmt.Fprintln(w, "OK")
}
}
func handleRateSet(limiter *Limiter) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
s := r.FormValue("limit")
if s == "" {
http.Error(w, "need limit", http.StatusBadRequest)
return
}
name := r.FormValue("name")
if name == "" {
http.Error(w, "need name", http.StatusBadRequest)
return
}
limit, err := strconv.Atoi(s)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
limiter.SetLimit(context.Background(), name, rate.Limit(limit))
fmt.Fprintln(w, "OK")
}
}
func handleRateSetAll(limiter *Limiter) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
s := r.FormValue("limit")
if s == "" {
http.Error(w, "need limit", http.StatusBadRequest)
return
}
limit, err := strconv.Atoi(s)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
limiter.SetAllLimits(context.Background(), rate.Limit(limit))
fmt.Fprintln(w, "OK")
}
}


@@ -1,179 +0,0 @@
package main
import (
"context"
"flag"
"fmt"
"io"
"log"
"net/http"
_ "net/http/pprof"
"os"
"os/signal"
"path/filepath"
"runtime"
"runtime/debug"
"strings"
"syscall"
"time"
_ "github.com/joho/godotenv/autoload"
"github.com/kelseyhightower/envconfig"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/rs/zerolog"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"github.com/uabluerail/indexer/util/gormzerolog"
)
type Config struct {
LogFile string
LogFormat string `default:"text"`
LogLevel int64 `default:"1"`
MetricsPort string `split_words:"true"`
DBUrl string `envconfig:"POSTGRES_URL"`
Workers int `default:"2"`
}
var config Config
func runMain(ctx context.Context) error {
ctx = setupLogging(ctx)
log := zerolog.Ctx(ctx)
log.Debug().Msgf("Starting up...")
db, err := gorm.Open(postgres.Open(config.DBUrl), &gorm.Config{
Logger: gormzerolog.New(&logger.Config{
SlowThreshold: 3 * time.Second,
IgnoreRecordNotFoundError: true,
}, nil),
})
if err != nil {
return fmt.Errorf("connecting to the database: %w", err)
}
log.Debug().Msgf("DB connection established")
limiter, err := NewLimiter(db)
if err != nil {
return fmt.Errorf("failed to create limiter: %w", err)
}
ch := make(chan WorkItem)
pool := NewWorkerPool(ch, db, config.Workers, limiter)
if err := pool.Start(ctx); err != nil {
return fmt.Errorf("failed to start worker pool: %w", err)
}
scheduler := NewScheduler(ch, db)
if err := scheduler.Start(ctx); err != nil {
return fmt.Errorf("failed to start scheduler: %w", err)
}
log.Info().Msgf("Starting HTTP listener on %q...", config.MetricsPort)
AddAdminHandlers(limiter, pool)
http.Handle("/metrics", promhttp.Handler())
srv := &http.Server{Addr: fmt.Sprintf(":%s", config.MetricsPort)}
errCh := make(chan error)
go func() {
errCh <- srv.ListenAndServe()
}()
select {
case <-ctx.Done():
if err := srv.Shutdown(context.Background()); err != nil {
return fmt.Errorf("HTTP server shutdown failed: %w", err)
}
}
return <-errCh
}
func main() {
flag.StringVar(&config.LogFile, "log", "", "Path to the log file. If empty, will log to stderr")
flag.StringVar(&config.LogFormat, "log-format", "text", "Logging format. 'text' or 'json'")
flag.Int64Var(&config.LogLevel, "log-level", 1, "Log level. -1 - trace, 0 - debug, 1 - info, 5 - panic")
flag.IntVar(&config.Workers, "workers", 2, "Number of workers to start with")
if err := envconfig.Process("indexer", &config); err != nil {
log.Fatalf("envconfig.Process: %s", err)
}
flag.Parse()
ctx, _ := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
if err := runMain(ctx); err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
}
func setupLogging(ctx context.Context) context.Context {
logFile := os.Stderr
if config.LogFile != "" {
f, err := os.OpenFile(config.LogFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
log.Fatalf("Failed to open the specified log file %q: %s", config.LogFile, err)
}
logFile = f
}
var output io.Writer
switch config.LogFormat {
case "json":
output = logFile
case "text":
prefixList := []string{}
info, ok := debug.ReadBuildInfo()
if ok {
prefixList = append(prefixList, info.Path+"/")
}
basedir := ""
_, sourceFile, _, ok := runtime.Caller(0)
if ok {
basedir = filepath.Dir(sourceFile)
}
if basedir != "" && strings.HasPrefix(basedir, "/") {
prefixList = append(prefixList, basedir+"/")
head, _ := filepath.Split(basedir)
for head != "/" {
prefixList = append(prefixList, head)
head, _ = filepath.Split(strings.TrimSuffix(head, "/"))
}
}
output = zerolog.ConsoleWriter{
Out: logFile,
NoColor: true,
TimeFormat: time.RFC3339,
PartsOrder: []string{
zerolog.LevelFieldName,
zerolog.TimestampFieldName,
zerolog.CallerFieldName,
zerolog.MessageFieldName,
},
FormatFieldName: func(i interface{}) string { return fmt.Sprintf("%s:", i) },
FormatFieldValue: func(i interface{}) string { return fmt.Sprintf("%s", i) },
FormatCaller: func(i interface{}) string {
s := i.(string)
for _, p := range prefixList {
s = strings.TrimPrefix(s, p)
}
return s
},
}
default:
log.Fatalf("Invalid log format specified: %q", config.LogFormat)
}
logger := zerolog.New(output).Level(zerolog.Level(config.LogLevel)).With().Caller().Timestamp().Logger()
ctx = logger.WithContext(ctx)
zerolog.DefaultContextLogger = &logger
log.SetOutput(logger)
return ctx
}


@@ -1,41 +0,0 @@
package main
import (
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
)
var reposQueued = promauto.NewCounter(prometheus.CounterOpts{
Name: "indexer_repos_queued_count",
Help: "Number of repos added to the queue",
})
var queueLength = promauto.NewGaugeVec(prometheus.GaugeOpts{
Name: "indexer_queue_length",
Help: "Current length of indexing queue",
}, []string{"state"})
var reposFetched = promauto.NewCounterVec(prometheus.CounterOpts{
Name: "indexer_repos_fetched_count",
Help: "Number of repos fetched",
}, []string{"remote", "success"})
var reposIndexed = promauto.NewCounterVec(prometheus.CounterOpts{
Name: "indexer_repos_indexed_count",
Help: "Number of repos indexed",
}, []string{"success"})
var recordsFetched = promauto.NewCounter(prometheus.CounterOpts{
Name: "indexer_records_fetched_count",
Help: "Number of records fetched",
})
var recordsInserted = promauto.NewCounter(prometheus.CounterOpts{
Name: "indexer_records_inserted_count",
Help: "Number of records inserted into DB",
})
var workerPoolSize = promauto.NewGauge(prometheus.GaugeOpts{
Name: "indexer_workers_count",
Help: "Current number of workers running",
})


@@ -1,82 +0,0 @@
package main
import (
"context"
"fmt"
"sync"
"github.com/rs/zerolog"
"github.com/uabluerail/indexer/pds"
"golang.org/x/time/rate"
"gorm.io/gorm"
)
const defaultRateLimit = 10
type Limiter struct {
mu sync.RWMutex
db *gorm.DB
limiter map[string]*rate.Limiter
}
func NewLimiter(db *gorm.DB) (*Limiter, error) {
remotes := []pds.PDS{}
if err := db.Find(&remotes).Error; err != nil {
return nil, fmt.Errorf("failed to get the list of known PDSs: %w", err)
}
l := &Limiter{
db: db,
limiter: map[string]*rate.Limiter{},
}
for _, remote := range remotes {
limit := remote.CrawlLimit
if limit == 0 {
limit = defaultRateLimit
}
l.limiter[remote.Host] = rate.NewLimiter(rate.Limit(limit), limit*2)
}
return l, nil
}
func (l *Limiter) getLimiter(name string) *rate.Limiter {
l.mu.RLock()
limiter := l.limiter[name]
l.mu.RUnlock()
if limiter != nil {
return limiter
}
limiter = rate.NewLimiter(defaultRateLimit, defaultRateLimit*2)
l.mu.Lock()
l.limiter[name] = limiter
l.mu.Unlock()
return limiter
}
func (l *Limiter) Wait(ctx context.Context, name string) error {
return l.getLimiter(name).Wait(ctx)
}
func (l *Limiter) SetLimit(ctx context.Context, name string, limit rate.Limit) {
l.getLimiter(name).SetLimit(limit)
err := l.db.Model(&pds.PDS{}).Where(&pds.PDS{Host: name}).Updates(&pds.PDS{CrawlLimit: int(limit)}).Error
if err != nil {
zerolog.Ctx(ctx).Error().Err(err).Msgf("Failed to persist rate limit change for %q: %s", name, err)
}
}
func (l *Limiter) SetAllLimits(ctx context.Context, limit rate.Limit) {
l.mu.RLock()
for name, limiter := range l.limiter {
limiter.SetLimit(limit)
err := l.db.Model(&pds.PDS{}).Where(&pds.PDS{Host: name}).Updates(&pds.PDS{CrawlLimit: int(limit)}).Error
if err != nil {
zerolog.Ctx(ctx).Error().Err(err).Msgf("Failed to persist rate limit change for %q: %s", name, err)
}
}
l.mu.RUnlock()
}


@@ -1,158 +0,0 @@
package main
import (
"context"
"fmt"
"slices"
"time"
"github.com/rs/zerolog"
"github.com/uabluerail/indexer/pds"
"github.com/uabluerail/indexer/repo"
"gorm.io/gorm"
)
type Scheduler struct {
db *gorm.DB
output chan<- WorkItem
queue map[string]*repo.Repo
inProgress map[string]*repo.Repo
}
func NewScheduler(output chan<- WorkItem, db *gorm.DB) *Scheduler {
return &Scheduler{
db: db,
output: output,
queue: map[string]*repo.Repo{},
inProgress: map[string]*repo.Repo{},
}
}
func (s *Scheduler) Start(ctx context.Context) error {
go s.run(ctx)
return nil
}
func (s *Scheduler) run(ctx context.Context) {
log := zerolog.Ctx(ctx)
t := time.NewTicker(time.Minute)
defer t.Stop()
if err := s.fillQueue(ctx); err != nil {
log.Error().Err(err).Msgf("Failed to get more tasks for the queue: %s", err)
}
done := make(chan string)
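// Main loop: hand queued repos to workers, refill the queue on a timer, and drop
// entries from inProgress once a worker signals completion via 'done'.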
for {
if len(s.queue) > 0 {
next := WorkItem{signal: make(chan struct{})}
for _, r := range s.queue {
next.Repo = r
break
}
select {
case <-ctx.Done():
return
case <-t.C:
if err := s.fillQueue(ctx); err != nil {
log.Error().Err(err).Msgf("Failed to get more tasks for the queue: %s", err)
}
case s.output <- next:
delete(s.queue, next.Repo.DID)
s.inProgress[next.Repo.DID] = next.Repo
go func(did string, ch chan struct{}) {
select {
case <-ch:
case <-ctx.Done():
}
done <- did
}(next.Repo.DID, next.signal)
s.updateQueueLenMetrics()
case did := <-done:
delete(s.inProgress, did)
s.updateQueueLenMetrics()
}
} else {
select {
case <-ctx.Done():
return
case <-t.C:
if err := s.fillQueue(ctx); err != nil {
log.Error().Err(err).Msgf("Failed to get more tasks for the queue: %s", err)
}
case did := <-done:
delete(s.inProgress, did)
s.updateQueueLenMetrics()
}
}
}
}
func (s *Scheduler) fillQueue(ctx context.Context) error {
const maxQueueLen = 10000
const maxAttempts = 3
if len(s.queue)+len(s.inProgress) >= maxQueueLen {
return nil
}
remotes := []pds.PDS{}
if err := s.db.Find(&remotes).Error; err != nil {
return fmt.Errorf("failed to get the list of PDSs: %w", err)
}
remotes = slices.DeleteFunc(remotes, func(pds pds.PDS) bool {
return pds.Disabled
})
perPDSLimit := maxQueueLen
if len(remotes) > 0 {
perPDSLimit = maxQueueLen * 2 / len(remotes)
}
if perPDSLimit < maxQueueLen/10 {
perPDSLimit = maxQueueLen / 10
}
// Fake remote to account for repos we didn't have a PDS for yet.
remotes = append(remotes, pds.PDS{ID: pds.Unknown})
for _, remote := range remotes {
repos := []repo.Repo{}
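// Select repos on this PDS that were never fully indexed, plus repos whose indexed
// state predates the PDS's last cursor reset, skipping ones with too many failed attempts.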
err := s.db.Raw(`SELECT * FROM "repos" WHERE pds = ? AND (last_indexed_rev is null OR last_indexed_rev = '') AND failed_attempts < ?
UNION
SELECT "repos".* FROM "repos" left join "pds" on repos.pds = pds.id WHERE pds = ?
AND
(
(first_rev_since_reset is not null AND first_rev_since_reset <> ''
AND last_indexed_rev < first_rev_since_reset)
OR
("repos".first_cursor_since_reset is not null AND "repos".first_cursor_since_reset <> 0
AND "repos".first_cursor_since_reset < "pds".first_cursor_since_reset)
)
AND failed_attempts < ? LIMIT ?`,
remote.ID, maxAttempts, remote.ID, maxAttempts, perPDSLimit).
Scan(&repos).Error
if err != nil {
return fmt.Errorf("querying DB: %w", err)
}
for _, r := range repos {
if s.queue[r.DID] != nil || s.inProgress[r.DID] != nil {
continue
}
copied := r
s.queue[r.DID] = &copied
reposQueued.Inc()
}
s.updateQueueLenMetrics()
}
return nil
}
func (s *Scheduler) updateQueueLenMetrics() {
queueLenght.WithLabelValues("queued").Set(float64(len(s.queue)))
queueLenght.WithLabelValues("inProgress").Set(float64(len(s.inProgress)))
}


@@ -1,315 +0,0 @@
package main
import (
"bytes"
"context"
"fmt"
"regexp"
"strings"
"time"
"github.com/imax9000/errors"
"github.com/rs/zerolog"
"gorm.io/gorm"
"gorm.io/gorm/clause"
comatproto "github.com/bluesky-social/indigo/api/atproto"
"github.com/bluesky-social/indigo/util"
"github.com/bluesky-social/indigo/xrpc"
"github.com/uabluerail/bsky-tools/xrpcauth"
"github.com/uabluerail/indexer/models"
"github.com/uabluerail/indexer/pds"
"github.com/uabluerail/indexer/repo"
"github.com/uabluerail/indexer/util/fix"
"github.com/uabluerail/indexer/util/resolver"
)
type WorkItem struct {
Repo *repo.Repo
signal chan struct{}
}
type WorkerPool struct {
db *gorm.DB
input <-chan WorkItem
limiter *Limiter
workerSignals []chan struct{}
resize chan int
}
func NewWorkerPool(input <-chan WorkItem, db *gorm.DB, size int, limiter *Limiter) *WorkerPool {
r := &WorkerPool{
db: db,
input: input,
limiter: limiter,
resize: make(chan int),
}
r.workerSignals = make([]chan struct{}, size)
for i := range r.workerSignals {
r.workerSignals[i] = make(chan struct{})
}
return r
}
func (p *WorkerPool) Start(ctx context.Context) error {
go p.run(ctx)
return nil
}
func (p *WorkerPool) Resize(ctx context.Context, size int) error {
select {
case <-ctx.Done():
return ctx.Err()
case p.resize <- size:
return nil
}
}
func (p *WorkerPool) run(ctx context.Context) {
for _, ch := range p.workerSignals {
go p.worker(ctx, ch)
}
workerPoolSize.Set(float64(len(p.workerSignals)))
for {
select {
case <-ctx.Done():
for _, ch := range p.workerSignals {
close(ch)
}
// also wait for all workers to stop?
return
case newSize := <-p.resize:
switch {
case newSize > len(p.workerSignals):
ch := make([]chan struct{}, newSize-len(p.workerSignals))
for i := range ch {
ch[i] = make(chan struct{})
go p.worker(ctx, ch[i])
}
p.workerSignals = append(p.workerSignals, ch...)
workerPoolSize.Set(float64(len(p.workerSignals)))
case newSize < len(p.workerSignals) && newSize > 0:
for _, ch := range p.workerSignals[newSize:] {
close(ch)
}
p.workerSignals = p.workerSignals[:newSize]
workerPoolSize.Set(float64(len(p.workerSignals)))
}
}
}
}
func (p *WorkerPool) worker(ctx context.Context, signal chan struct{}) {
log := zerolog.Ctx(ctx)
for {
select {
case <-ctx.Done():
return
case <-signal:
return
case work := <-p.input:
updates := &repo.Repo{}
if err := p.doWork(ctx, work); err != nil {
log.Error().Err(err).Msgf("Work task %q failed: %s", work.Repo.DID, err)
updates.LastError = err.Error()
updates.FailedAttempts = work.Repo.FailedAttempts + 1
reposIndexed.WithLabelValues("false").Inc()
} else {
updates.FailedAttempts = 0
reposIndexed.WithLabelValues("true").Inc()
}
updates.LastIndexAttempt = time.Now()
err := p.db.Model(&repo.Repo{}).
Where(&repo.Repo{ID: work.Repo.ID}).
Select("last_error", "last_index_attempt", "failed_attempts").
Updates(updates).Error
if err != nil {
log.Error().Err(err).Msgf("Failed to update repo info for %q: %s", work.Repo.DID, err)
}
}
}
}
func (p *WorkerPool) doWork(ctx context.Context, work WorkItem) error {
log := zerolog.Ctx(ctx)
defer close(work.signal)
u, pubKey, err := resolver.GetPDSEndpointAndPublicKey(ctx, work.Repo.DID)
if err != nil {
return err
}
remote, err := pds.EnsureExists(ctx, p.db, u.String())
if err != nil {
return fmt.Errorf("failed to get PDS records for %q: %w", u, err)
}
if work.Repo.PDS != remote.ID {
if err := p.db.Model(&work.Repo).Where(&repo.Repo{ID: work.Repo.ID}).Updates(&repo.Repo{PDS: remote.ID}).Error; err != nil {
return fmt.Errorf("failed to update repo's PDS to %q: %w", u, err)
}
work.Repo.PDS = remote.ID
}
client := xrpcauth.NewAnonymousClient(ctx)
client.Host = u.String()
client.Client = util.RobustHTTPClient()
client.Client.Timeout = 30 * time.Minute
knownCursorBeforeFetch := remote.FirstCursorSinceReset
retry:
if p.limiter != nil {
if err := p.limiter.Wait(ctx, u.String()); err != nil {
return fmt.Errorf("failed to wait on rate limiter: %w", err)
}
}
// TODO: add a configuration knob for switching between full and partial fetch.
sinceRev := work.Repo.LastIndexedRev
b, err := comatproto.SyncGetRepo(ctx, client, work.Repo.DID, sinceRev)
if err != nil {
if err, ok := errors.As[*xrpc.Error](err); ok {
if err.IsThrottled() && err.Ratelimit != nil {
log.Debug().Str("pds", u.String()).Msgf("Hit a rate limit (%s), sleeping until %s", err.Ratelimit.Policy, err.Ratelimit.Reset)
time.Sleep(time.Until(err.Ratelimit.Reset))
goto retry
}
}
reposFetched.WithLabelValues(u.String(), "false").Inc()
return fmt.Errorf("failed to fetch repo: %w", err)
}
if len(b) == 0 {
reposFetched.WithLabelValues(u.String(), "false").Inc()
return fmt.Errorf("PDS returned zero bytes")
}
reposFetched.WithLabelValues(u.String(), "true").Inc()
if work.Repo.PDS == pds.Unknown {
remote, err := pds.EnsureExists(ctx, p.db, u.String())
if err != nil {
return err
}
work.Repo.PDS = remote.ID
if err := p.db.Model(&work.Repo).Where(&repo.Repo{ID: work.Repo.ID}).Updates(&repo.Repo{PDS: work.Repo.PDS}).Error; err != nil {
return fmt.Errorf("failed to set repo's PDS: %w", err)
}
}
newRev, err := repo.GetRev(ctx, bytes.NewReader(b))
if sinceRev != "" && errors.Is(err, repo.ErrZeroBlocks) {
// No new records since the rev we requested above.
if work.Repo.FirstCursorSinceReset < knownCursorBeforeFetch {
if err := p.bumpFirstCursorSinceReset(work.Repo.ID, knownCursorBeforeFetch); err != nil {
return fmt.Errorf("updating first_cursor_since_reset: %w", err)
}
}
return nil
} else if err != nil {
l := 25
if len(b) < l {
l = len(b)
}
log.Debug().Err(err).Msgf("Total bytes fetched: %d. First few bytes: %q", len(b), string(b[:l]))
return fmt.Errorf("failed to read 'rev' from the fetched repo: %w", err)
}
newRecs, err := repo.ExtractRecords(ctx, bytes.NewReader(b), pubKey)
if err != nil {
return fmt.Errorf("failed to extract records: %w", err)
}
recs := []repo.Record{}
for k, v := range newRecs {
parts := strings.SplitN(k, "/", 2)
if len(parts) != 2 {
log.Warn().Msgf("Unexpected key format: %q", k)
continue
}
v = regexp.MustCompile(`[^\\](\\\\)*(\\u0000)`).ReplaceAll(v, []byte(`$1<0x00>`))
recs = append(recs, repo.Record{
Repo: models.ID(work.Repo.ID),
Collection: parts[0],
Rkey: parts[1],
// XXX: proper replacement of \u0000 would require full parsing of JSON
// and recursive iteration over all string values, but this
// should work well enough for now.
Content: fix.EscapeNullCharForPostgres(v),
AtRev: newRev,
})
}
recordsFetched.Add(float64(len(recs)))
if len(recs) > 0 {
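// Upsert in batches of up to 500 rows. On a (repo, collection, rkey) conflict
// the existing row is overwritten only when the content actually differs and
// its at_rev is empty or older than the rev we just fetched.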
for _, batch := range splitInBatches(recs, 500) {
result := p.db.Model(&repo.Record{}).
Clauses(clause.OnConflict{
Where: clause.Where{Exprs: []clause.Expression{
clause.Neq{
Column: clause.Column{Name: "content", Table: "records"},
Value: clause.Column{Name: "content", Table: "excluded"}},
clause.Or(
clause.Eq{Column: clause.Column{Name: "at_rev", Table: "records"}, Value: nil},
clause.Eq{Column: clause.Column{Name: "at_rev", Table: "records"}, Value: ""},
clause.Lt{
Column: clause.Column{Name: "at_rev", Table: "records"},
Value: clause.Column{Name: "at_rev", Table: "excluded"}},
)}},
DoUpdates: clause.AssignmentColumns([]string{"content", "at_rev"}),
Columns: []clause.Column{{Name: "repo"}, {Name: "collection"}, {Name: "rkey"}}}).
Create(batch)
if err := result.Error; err != nil {
return fmt.Errorf("inserting records into the database: %w", err)
}
recordsInserted.Add(float64(result.RowsAffected))
}
}
err = p.db.Model(&repo.Repo{}).Where(&repo.Repo{ID: work.Repo.ID}).
Updates(&repo.Repo{LastIndexedRev: newRev}).Error
if err != nil {
return fmt.Errorf("updating repo rev: %w", err)
}
if work.Repo.FirstCursorSinceReset < knownCursorBeforeFetch {
if err := p.bumpFirstCursorSinceReset(work.Repo.ID, knownCursorBeforeFetch); err != nil {
return fmt.Errorf("updating first_cursor_since_reset: %w", err)
}
}
// TODO: check for records that are missing in the repo download
// and mark them as deleted.
return nil
}
// bumpFirstCursorSinceReset increases repo's FirstCursorSinceReset iff it is currently lower than the supplied value.
func (p *WorkerPool) bumpFirstCursorSinceReset(repoId models.ID, cursorValue int64) error {
return p.db.Transaction(func(tx *gorm.DB) error {
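// Check-then-set inside a single transaction: the stored cursor is only ever
// raised by this function, never lowered.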
var currentCursor int64
err := tx.Model(&repo.Repo{}).Where(&repo.Repo{ID: repoId}).
Select("first_cursor_since_reset").First(&currentCursor).Error
if err != nil {
return fmt.Errorf("failed to get current cursor value: %w", err)
}
if currentCursor < cursorValue {
return tx.Model(&repo.Repo{}).Where(&repo.Repo{ID: repoId}).
Updates(&repo.Repo{FirstCursorSinceReset: cursorValue}).Error
}
return nil
})
}
func splitInBatches[T any](s []T, batchSize int) [][]T {
var r [][]T
for i := 0; i < len(s); i += batchSize {
if i+batchSize < len(s) {
r = append(r, s[i:i+batchSize])
} else {
r = append(r, s[i:])
}
}
return r
}

View file

@@ -1,14 +0,0 @@
FROM golang:1.22.3 as builder
WORKDIR /app
COPY go.mod go.sum ./
RUN go mod download
COPY . ./
RUN go build -trimpath ./cmd/update-db-schema
FROM alpine:latest as certs
RUN apk --update add ca-certificates
FROM debian:stable-slim
COPY --from=certs /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/ca-certificates.crt
COPY --from=builder /app/update-db-schema .
ENTRYPOINT ["./update-db-schema"]

View file

@@ -1,7 +0,0 @@
*
**/*
!go.mod
!go.sum
!**/*.go
cmd/**
!cmd/update-db-schema

View file

@@ -1,155 +0,0 @@
package main
import (
"context"
"flag"
"fmt"
"io"
"log"
_ "net/http/pprof"
"os"
"os/signal"
"path/filepath"
"runtime"
"runtime/debug"
"strings"
"syscall"
"time"
_ "github.com/joho/godotenv/autoload"
"github.com/kelseyhightower/envconfig"
"github.com/rs/zerolog"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"github.com/uabluerail/indexer/pds"
"github.com/uabluerail/indexer/repo"
"github.com/uabluerail/indexer/util/gormzerolog"
)
type Config struct {
LogFile string
LogFormat string `default:"text"`
LogLevel int64 `default:"1"`
DBUrl string `envconfig:"POSTGRES_URL"`
}
var config Config
func runMain(ctx context.Context) error {
ctx = setupLogging(ctx)
log := zerolog.Ctx(ctx)
log.Debug().Msgf("Starting up...")
db, err := gorm.Open(postgres.Open(config.DBUrl), &gorm.Config{
Logger: gormzerolog.New(&logger.Config{
SlowThreshold: 1 * time.Second,
IgnoreRecordNotFoundError: true,
}, nil),
})
if err != nil {
return fmt.Errorf("connecting to the database: %w", err)
}
log.Debug().Msgf("DB connection established")
for _, f := range []func(*gorm.DB) error{
pds.AutoMigrate,
repo.AutoMigrate,
} {
if err := f(db); err != nil {
return fmt.Errorf("auto-migrating DB schema: %w", err)
}
}
log.Debug().Msgf("DB schema updated")
return nil
}
func main() {
flag.StringVar(&config.LogFile, "log", "", "Path to the log file. If empty, will log to stderr")
flag.StringVar(&config.LogFormat, "log-format", "text", "Logging format. 'text' or 'json'")
flag.Int64Var(&config.LogLevel, "log-level", 1, "Log level. -1 - trace, 0 - debug, 1 - info, 5 - panic")
if err := envconfig.Process("update-db-schema", &config); err != nil {
log.Fatalf("envconfig.Process: %s", err)
}
flag.Parse()
ctx, _ := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
if err := runMain(ctx); err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
}
func setupLogging(ctx context.Context) context.Context {
logFile := os.Stderr
if config.LogFile != "" {
f, err := os.OpenFile(config.LogFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
log.Fatalf("Failed to open the specified log file %q: %s", config.LogFile, err)
}
logFile = f
}
var output io.Writer
switch config.LogFormat {
case "json":
output = logFile
case "text":
prefixList := []string{}
info, ok := debug.ReadBuildInfo()
if ok {
prefixList = append(prefixList, info.Path+"/")
}
basedir := ""
_, sourceFile, _, ok := runtime.Caller(0)
if ok {
basedir = filepath.Dir(sourceFile)
}
if basedir != "" && strings.HasPrefix(basedir, "/") {
prefixList = append(prefixList, basedir+"/")
head, _ := filepath.Split(basedir)
for head != "/" {
prefixList = append(prefixList, head)
head, _ = filepath.Split(strings.TrimSuffix(head, "/"))
}
}
output = zerolog.ConsoleWriter{
Out: logFile,
NoColor: true,
TimeFormat: time.RFC3339,
PartsOrder: []string{
zerolog.LevelFieldName,
zerolog.TimestampFieldName,
zerolog.CallerFieldName,
zerolog.MessageFieldName,
},
FormatFieldName: func(i interface{}) string { return fmt.Sprintf("%s:", i) },
FormatFieldValue: func(i interface{}) string { return fmt.Sprintf("%s", i) },
FormatCaller: func(i interface{}) string {
s := i.(string)
for _, p := range prefixList {
s = strings.TrimPrefix(s, p)
}
return s
},
}
default:
log.Fatalf("Invalid log format specified: %q", config.LogFormat)
}
logger := zerolog.New(output).Level(zerolog.Level(config.LogLevel)).With().Caller().Timestamp().Logger()
ctx = logger.WithContext(ctx)
zerolog.DefaultContextLogger = &logger
log.SetOutput(logger)
return ctx
}