This commit is contained in:
Max Ignatenko 2024-02-15 16:10:39 +00:00
parent 2b6abac607
commit 63a767d890
25 changed files with 3027 additions and 0 deletions

14
cmd/consumer/Dockerfile Normal file
View file

@ -0,0 +1,14 @@
FROM golang:1.21.1 as builder
WORKDIR /app
COPY go.mod go.sum ./
RUN go mod download
COPY . ./
RUN go build -trimpath ./cmd/consumer
FROM alpine:latest as certs
RUN apk --update add ca-certificates
FROM debian:stable-slim
COPY --from=certs /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/ca-certificates.crt
COPY --from=builder /app/consumer .
ENTRYPOINT ["./consumer"]

431
cmd/consumer/consumer.go Normal file
View file

@ -0,0 +1,431 @@
package main
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"math"
"net/http"
"net/url"
"path"
"strings"
"time"
"github.com/gorilla/websocket"
"github.com/rs/zerolog"
"gorm.io/gorm"
"gorm.io/gorm/clause"
comatproto "github.com/bluesky-social/indigo/api/atproto"
"github.com/bluesky-social/indigo/xrpc"
"github.com/ipld/go-ipld-prime/codec/dagcbor"
"github.com/ipld/go-ipld-prime/datamodel"
"github.com/ipld/go-ipld-prime/node/basicnode"
"github.com/uabluerail/indexer/models"
"github.com/uabluerail/indexer/pds"
"github.com/uabluerail/indexer/repo"
)
type Consumer struct {
db *gorm.DB
remote pds.PDS
lastCursorPersist time.Time
}
func NewConsumer(ctx context.Context, remote *pds.PDS, db *gorm.DB) (*Consumer, error) {
return &Consumer{
db: db,
remote: *remote,
}, nil
}
func (c *Consumer) Start(ctx context.Context) error {
go c.run(ctx)
return nil
}
func (c *Consumer) run(ctx context.Context) {
log := zerolog.Ctx(ctx).With().Str("pds", c.remote.Host).Logger()
ctx = log.WithContext(ctx)
for {
if err := c.runOnce(ctx); err != nil {
log.Error().Err(err).Msgf("Consumer of %q failed (will be restarted): %s", c.remote.Host, err)
}
time.Sleep(time.Second)
}
}
func (c *Consumer) runOnce(ctx context.Context) error {
log := zerolog.Ctx(ctx)
addr, err := url.Parse(c.remote.Host)
if err != nil {
return fmt.Errorf("parsing URL %q: %s", c.remote.Host, err)
}
addr.Scheme = "wss"
addr.Path = path.Join(addr.Path, "xrpc/com.atproto.sync.subscribeRepos")
if c.remote.Cursor > 0 {
params := url.Values{"cursor": []string{fmt.Sprint(c.remote.Cursor)}}
addr.RawQuery = params.Encode()
}
conn, _, err := websocket.DefaultDialer.DialContext(ctx, addr.String(), http.Header{})
if err != nil {
return fmt.Errorf("establishing websocker connection: %w", err)
}
defer conn.Close()
ch := make(chan bool)
defer close(ch)
go func() {
t := time.NewTicker(time.Minute)
for {
select {
case <-ch:
return
case <-t.C:
if err := conn.WriteControl(websocket.PingMessage, []byte("ping"), time.Now().Add(time.Minute)); err != nil {
log.Error().Err(err).Msgf("Failed to send ping: %s", err)
}
}
}
}()
first := true
for {
_, b, err := conn.ReadMessage()
if err != nil {
return fmt.Errorf("websocket.ReadMessage: %w", err)
}
r := bytes.NewReader(b)
proto := basicnode.Prototype.Any
headerNode := proto.NewBuilder()
if err := (&dagcbor.DecodeOptions{DontParseBeyondEnd: true}).Decode(headerNode, r); err != nil {
return fmt.Errorf("unmarshaling message header: %w", err)
}
header, err := parseHeader(headerNode.Build())
if err != nil {
return fmt.Errorf("parsing message header: %w", err)
}
switch header.Op {
case 1:
if err := c.processMessage(ctx, header.Type, r, first); err != nil {
return err
}
case -1:
bodyNode := proto.NewBuilder()
if err := (&dagcbor.DecodeOptions{DontParseBeyondEnd: true, AllowLinks: true}).Decode(bodyNode, r); err != nil {
return fmt.Errorf("unmarshaling message body: %w", err)
}
body, err := parseError(bodyNode.Build())
if err != nil {
return fmt.Errorf("parsing error payload: %w", err)
}
return &body
default:
log.Warn().Msgf("Unknown 'op' value received: %d", header.Op)
}
first = false
}
}
func (c *Consumer) checkForCursorReset(ctx context.Context, seq int64) error {
// hack to detect cursor resets upon connection for implementations
// that don't emit an explicit #info when connecting with an outdated cursor.
if seq == c.remote.Cursor+1 {
// No reset.
return nil
}
return c.resetCursor(ctx, seq)
}
func (c *Consumer) resetCursor(ctx context.Context, seq int64) error {
zerolog.Ctx(ctx).Warn().Str("pds", c.remote.Host).Msgf("Cursor reset: %d -> %d", c.remote.Cursor, seq)
err := c.db.Model(&c.remote).
Where(&pds.PDS{Model: gorm.Model{ID: c.remote.ID}}).
Updates(&pds.PDS{FirstCursorSinceReset: seq}).Error
if err != nil {
return fmt.Errorf("updating FirstCursorSinceReset: %w", err)
}
c.remote.FirstCursorSinceReset = seq
return nil
}
func (c *Consumer) updateCursor(ctx context.Context, seq int64) error {
if math.Abs(float64(seq-c.remote.Cursor)) < 100 && time.Since(c.lastCursorPersist) < 5*time.Second {
c.remote.Cursor = seq
return nil
}
err := c.db.Model(&c.remote).
Where(&pds.PDS{Model: gorm.Model{ID: c.remote.ID}}).
Updates(&pds.PDS{Cursor: seq}).Error
if err != nil {
return fmt.Errorf("updating Cursor: %w", err)
}
c.remote.Cursor = seq
return nil
}
func (c *Consumer) processMessage(ctx context.Context, typ string, r io.Reader, first bool) error {
log := zerolog.Ctx(ctx)
switch typ {
case "#commit":
payload := &comatproto.SyncSubscribeRepos_Commit{}
if err := payload.UnmarshalCBOR(r); err != nil {
return fmt.Errorf("failed to unmarshal commit: %w", err)
}
if c.remote.FirstCursorSinceReset == 0 {
if err := c.resetCursor(ctx, payload.Seq); err != nil {
return fmt.Errorf("handling cursor reset: %w", err)
}
}
if first {
if err := c.checkForCursorReset(ctx, payload.Seq); err != nil {
return err
}
}
repoInfo, err := repo.EnsureExists(ctx, c.db, payload.Repo)
if err != nil {
return fmt.Errorf("repo.EnsureExists(%q): %w", payload.Repo, err)
}
if repoInfo.PDS != models.ID(c.remote.ID) {
log.Error().Str("did", payload.Repo).Str("rev", payload.Rev).
Msgf("Commit from an incorrect PDS, skipping")
return nil
}
// TODO: verify signature
expectRecords := false
deletions := []string{}
for _, op := range payload.Ops {
switch op.Action {
case "create":
expectRecords = true
case "update":
expectRecords = true
case "delete":
deletions = append(deletions, op.Path)
}
}
for _, d := range deletions {
parts := strings.SplitN(d, "/", 2)
if len(parts) != 2 {
continue
}
err := c.db.Model(&repo.Record{}).
Where(&repo.Record{
Repo: models.ID(repoInfo.ID),
Collection: parts[0],
Rkey: parts[1]}).
Updates(&repo.Record{Deleted: true}).Error
if err != nil {
return fmt.Errorf("failed to mark %s/%s as deleted: %w", payload.Repo, d, err)
}
}
newRecs, err := repo.ExtractRecords(ctx, bytes.NewReader(payload.Blocks))
if err != nil {
return fmt.Errorf("failed to extract records: %w", err)
}
recs := []repo.Record{}
for k, v := range newRecs {
parts := strings.SplitN(k, "/", 2)
if len(parts) != 2 {
log.Warn().Msgf("Unexpected key format: %q", k)
continue
}
recs = append(recs, repo.Record{
Repo: models.ID(repoInfo.ID),
Collection: parts[0],
Rkey: parts[1],
Content: v,
})
}
if len(recs) == 0 && expectRecords {
log.Debug().Int64("seq", payload.Seq).Str("pds", c.remote.Host).Msgf("len(recs) == 0")
}
if len(recs) > 0 || expectRecords {
err = c.db.Model(&repo.Record{}).
Clauses(clause.OnConflict{DoUpdates: clause.AssignmentColumns([]string{"content"}),
Columns: []clause.Column{{Name: "repo"}, {Name: "collection"}, {Name: "rkey"}}}).
Create(recs).Error
if err != nil {
return fmt.Errorf("inserting records into the database: %w", err)
}
}
if payload.TooBig {
// Just trigger a re-index by resetting rev.
err := c.db.Model(r).Where(&repo.Repo{Model: gorm.Model{ID: repoInfo.ID}}).
Updates(&repo.Repo{
FirstCursorSinceReset: c.remote.FirstCursorSinceReset,
FirstRevSinceReset: payload.Rev,
}).Error
if err != nil {
return fmt.Errorf("failed to update repo info after cursor reset: %w", err)
}
}
if repoInfo.FirstCursorSinceReset != c.remote.FirstCursorSinceReset {
err := c.db.Model(r).Where(&repo.Repo{Model: gorm.Model{ID: repoInfo.ID}}).
Updates(&repo.Repo{
FirstCursorSinceReset: c.remote.FirstCursorSinceReset,
FirstRevSinceReset: payload.Rev,
}).Error
if err != nil {
return fmt.Errorf("failed to update repo info after cursor reset: %w", err)
}
}
if err := c.updateCursor(ctx, payload.Seq); err != nil {
return err
}
case "#handle":
payload := &comatproto.SyncSubscribeRepos_Handle{}
if err := payload.UnmarshalCBOR(r); err != nil {
return fmt.Errorf("failed to unmarshal commit: %w", err)
}
if c.remote.FirstCursorSinceReset == 0 {
if err := c.resetCursor(ctx, payload.Seq); err != nil {
return fmt.Errorf("handling cursor reset: %w", err)
}
}
if first {
if err := c.checkForCursorReset(ctx, payload.Seq); err != nil {
return err
}
}
// No-op, we don't store handles.
if err := c.updateCursor(ctx, payload.Seq); err != nil {
return err
}
case "#migrate":
payload := &comatproto.SyncSubscribeRepos_Migrate{}
if err := payload.UnmarshalCBOR(r); err != nil {
return fmt.Errorf("failed to unmarshal commit: %w", err)
}
if c.remote.FirstCursorSinceReset == 0 {
if err := c.resetCursor(ctx, payload.Seq); err != nil {
return fmt.Errorf("handling cursor reset: %w", err)
}
}
if first {
if err := c.checkForCursorReset(ctx, payload.Seq); err != nil {
return err
}
}
log.Debug().Interface("payload", payload).Str("did", payload.Did).Msgf("MIGRATION")
// TODO
if err := c.updateCursor(ctx, payload.Seq); err != nil {
return err
}
case "#tombstone":
payload := &comatproto.SyncSubscribeRepos_Tombstone{}
if err := payload.UnmarshalCBOR(r); err != nil {
return fmt.Errorf("failed to unmarshal commit: %w", err)
}
if c.remote.FirstCursorSinceReset == 0 {
if err := c.resetCursor(ctx, payload.Seq); err != nil {
return fmt.Errorf("handling cursor reset: %w", err)
}
}
if first {
if err := c.checkForCursorReset(ctx, payload.Seq); err != nil {
return err
}
}
// TODO
if err := c.updateCursor(ctx, payload.Seq); err != nil {
return err
}
case "#info":
payload := &comatproto.SyncSubscribeRepos_Info{}
if err := payload.UnmarshalCBOR(r); err != nil {
return fmt.Errorf("failed to unmarshal commit: %w", err)
}
switch payload.Name {
case "OutdatedCursor":
if !first {
log.Warn().Msgf("Received cursor reset notification in the middle of a stream: %+v", payload)
}
c.remote.FirstCursorSinceReset = 0
default:
log.Error().Msgf("Unknown #info message %q: %+v", payload.Name, payload)
}
default:
log.Warn().Msgf("Unknown message type received: %s", typ)
}
return nil
}
type Header struct {
Op int64
Type string
}
func parseHeader(node datamodel.Node) (Header, error) {
r := Header{}
op, err := node.LookupByString("op")
if err != nil {
return r, fmt.Errorf("missing 'op': %w", err)
}
r.Op, err = op.AsInt()
if err != nil {
return r, fmt.Errorf("op.AsInt(): %w", err)
}
if r.Op == -1 {
// Error frame, type should not be present
return r, nil
}
t, err := node.LookupByString("t")
if err != nil {
return r, fmt.Errorf("missing 't': %w", err)
}
r.Type, err = t.AsString()
if err != nil {
return r, fmt.Errorf("t.AsString(): %w", err)
}
return r, nil
}
func parseError(node datamodel.Node) (xrpc.XRPCError, error) {
r := xrpc.XRPCError{}
e, err := node.LookupByString("error")
if err != nil {
return r, fmt.Errorf("missing 'error': %w", err)
}
r.ErrStr, err = e.AsString()
if err != nil {
return r, fmt.Errorf("error.AsString(): %w", err)
}
m, err := node.LookupByString("message")
if err == nil {
r.Message, err = m.AsString()
if err != nil {
return r, fmt.Errorf("message.AsString(): %w", err)
}
} else if !errors.Is(err, datamodel.ErrNotExists{}) {
return r, fmt.Errorf("looking up 'message': %w", err)
}
return r, nil
}

176
cmd/consumer/main.go Normal file
View file

@ -0,0 +1,176 @@
package main
import (
"context"
"flag"
"fmt"
"io"
"log"
"net/http"
_ "net/http/pprof"
"os"
"os/signal"
"path/filepath"
"runtime"
"runtime/debug"
"strings"
"syscall"
"time"
_ "github.com/joho/godotenv/autoload"
"github.com/kelseyhightower/envconfig"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/rs/zerolog"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"github.com/uabluerail/indexer/pds"
"github.com/uabluerail/indexer/util/gormzerolog"
)
type Config struct {
LogFile string
LogFormat string `default:"text"`
LogLevel int64 `default:"1"`
MetricsPort string `split_words:"true"`
DBUrl string `envconfig:"POSTGRES_URL"`
}
var config Config
func runMain(ctx context.Context) error {
ctx = setupLogging(ctx)
log := zerolog.Ctx(ctx)
log.Debug().Msgf("Starting up...")
db, err := gorm.Open(postgres.Open(config.DBUrl), &gorm.Config{
Logger: gormzerolog.New(&logger.Config{
SlowThreshold: 3 * time.Second,
IgnoreRecordNotFoundError: true,
}, nil),
})
if err != nil {
return fmt.Errorf("connecting to the database: %w", err)
}
log.Debug().Msgf("DB connection established")
remotes := []pds.PDS{}
if err := db.Find(&remotes).Error; err != nil {
return fmt.Errorf("listing known PDSs: %w", err)
}
// TODO: check for changes and start/stop consumers as needed
for _, remote := range remotes {
c, err := NewConsumer(ctx, &remote, db)
if err != nil {
return fmt.Errorf("failed to create a consumer for %q: %w", remote.Host, err)
}
if err := c.Start(ctx); err != nil {
return fmt.Errorf("failed ot start a consumer for %q: %w", remote.Host, err)
}
}
log.Info().Msgf("Starting HTTP listener on %q...", config.MetricsPort)
http.Handle("/metrics", promhttp.Handler())
srv := &http.Server{Addr: fmt.Sprintf(":%s", config.MetricsPort)}
errCh := make(chan error)
go func() {
errCh <- srv.ListenAndServe()
}()
select {
case <-ctx.Done():
if err := srv.Shutdown(context.Background()); err != nil {
return fmt.Errorf("HTTP server shutdown failed: %w", err)
}
}
return <-errCh
}
func main() {
flag.StringVar(&config.LogFile, "log", "", "Path to the log file. If empty, will log to stderr")
flag.StringVar(&config.LogFormat, "log-format", "text", "Logging format. 'text' or 'json'")
flag.Int64Var(&config.LogLevel, "log-level", 1, "Log level. -1 - trace, 0 - debug, 1 - info, 5 - panic")
if err := envconfig.Process("consumer", &config); err != nil {
log.Fatalf("envconfig.Process: %s", err)
}
flag.Parse()
ctx, _ := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
if err := runMain(ctx); err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
}
func setupLogging(ctx context.Context) context.Context {
logFile := os.Stderr
if config.LogFile != "" {
f, err := os.OpenFile(config.LogFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
log.Fatalf("Failed to open the specified log file %q: %s", config.LogFile, err)
}
logFile = f
}
var output io.Writer
switch config.LogFormat {
case "json":
output = logFile
case "text":
prefixList := []string{}
info, ok := debug.ReadBuildInfo()
if ok {
prefixList = append(prefixList, info.Path+"/")
}
basedir := ""
_, sourceFile, _, ok := runtime.Caller(0)
if ok {
basedir = filepath.Dir(sourceFile)
}
if basedir != "" && strings.HasPrefix(basedir, "/") {
prefixList = append(prefixList, basedir+"/")
head, _ := filepath.Split(basedir)
for head != "/" {
prefixList = append(prefixList, head)
head, _ = filepath.Split(strings.TrimSuffix(head, "/"))
}
}
output = zerolog.ConsoleWriter{
Out: logFile,
NoColor: true,
TimeFormat: time.RFC3339,
PartsOrder: []string{
zerolog.LevelFieldName,
zerolog.TimestampFieldName,
zerolog.CallerFieldName,
zerolog.MessageFieldName,
},
FormatFieldName: func(i interface{}) string { return fmt.Sprintf("%s:", i) },
FormatFieldValue: func(i interface{}) string { return fmt.Sprintf("%s", i) },
FormatCaller: func(i interface{}) string {
s := i.(string)
for _, p := range prefixList {
s = strings.TrimPrefix(s, p)
}
return s
},
}
default:
log.Fatalf("Invalid log format specified: %q", config.LogFormat)
}
logger := zerolog.New(output).Level(zerolog.Level(config.LogLevel)).With().Caller().Timestamp().Logger()
ctx = logger.WithContext(ctx)
zerolog.DefaultContextLogger = &logger
log.SetOutput(logger)
return ctx
}