Keep the set of running consumers up to date

main
Max Ignatenko 2024-03-28 20:02:48 +00:00
parent 337f3ef2b8
commit c919050833
2 changed files with 176 additions and 75 deletions

View File

@ -14,6 +14,7 @@ import (
"time" "time"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/prometheus/client_golang/prometheus"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
@ -43,6 +44,7 @@ type BadRecord struct {
type Consumer struct { type Consumer struct {
db *gorm.DB db *gorm.DB
remote pds.PDS remote pds.PDS
running chan struct{}
lastCursorPersist time.Time lastCursorPersist time.Time
} }
@ -55,6 +57,7 @@ func NewConsumer(ctx context.Context, remote *pds.PDS, db *gorm.DB) (*Consumer,
return &Consumer{ return &Consumer{
db: db, db: db,
remote: *remote, remote: *remote,
running: make(chan struct{}),
}, nil }, nil
} }
@ -63,16 +66,41 @@ func (c *Consumer) Start(ctx context.Context) error {
return nil return nil
} }
func (c *Consumer) Wait(ctx context.Context) error {
select {
case <-ctx.Done():
return ctx.Err()
case <-c.running:
// Channel got closed
return nil
}
}
func (c *Consumer) run(ctx context.Context) { func (c *Consumer) run(ctx context.Context) {
log := zerolog.Ctx(ctx).With().Str("pds", c.remote.Host).Logger() log := zerolog.Ctx(ctx).With().Str("pds", c.remote.Host).Logger()
ctx = log.WithContext(ctx) ctx = log.WithContext(ctx)
defer close(c.running)
for { for {
select {
case <-c.running:
log.Error().Msgf("Attempt to start previously stopped consumer")
return
case <-ctx.Done():
log.Info().Msgf("Consumer stopped")
lastEventTimestamp.DeletePartialMatch(prometheus.Labels{"remote": c.remote.Host})
eventCounter.DeletePartialMatch(prometheus.Labels{"remote": c.remote.Host})
reposDiscovered.DeletePartialMatch(prometheus.Labels{"remote": c.remote.Host})
postsByLanguageIndexed.DeletePartialMatch(prometheus.Labels{"remote": c.remote.Host})
return
default:
if err := c.runOnce(ctx); err != nil { if err := c.runOnce(ctx); err != nil {
log.Error().Err(err).Msgf("Consumer of %q failed (will be restarted): %s", c.remote.Host, err) log.Error().Err(err).Msgf("Consumer of %q failed (will be restarted): %s", c.remote.Host, err)
} }
time.Sleep(time.Second) time.Sleep(time.Second)
} }
}
} }
func (c *Consumer) runOnce(ctx context.Context) error { func (c *Consumer) runOnce(ctx context.Context) error {
@ -120,6 +148,10 @@ func (c *Consumer) runOnce(ctx context.Context) error {
first := true first := true
for { for {
select {
case <-ctx.Done():
return ctx.Err()
default:
_, b, err := conn.ReadMessage() _, b, err := conn.ReadMessage()
if err != nil { if err != nil {
return fmt.Errorf("websocket.ReadMessage: %w", err) return fmt.Errorf("websocket.ReadMessage: %w", err)
@ -174,6 +206,7 @@ func (c *Consumer) runOnce(ctx context.Context) error {
} }
first = false first = false
} }
}
} }
func (c *Consumer) checkForCursorReset(ctx context.Context, seq int64) error { func (c *Consumer) checkForCursorReset(ctx context.Context, seq int64) error {

View File

@ -54,23 +54,8 @@ func runMain(ctx context.Context) error {
} }
log.Debug().Msgf("DB connection established") log.Debug().Msgf("DB connection established")
remotes := []pds.PDS{} consumersCh := make(chan struct{})
if err := db.Find(&remotes).Error; err != nil { go runConsumers(ctx, db, consumersCh)
return fmt.Errorf("listing known PDSs: %w", err)
}
// TODO: check for changes and start/stop consumers as needed
for _, remote := range remotes {
if remote.Disabled {
continue
}
c, err := NewConsumer(ctx, &remote, db)
if err != nil {
return fmt.Errorf("failed to create a consumer for %q: %w", remote.Host, err)
}
if err := c.Start(ctx); err != nil {
return fmt.Errorf("failed ot start a consumer for %q: %w", remote.Host, err)
}
}
log.Info().Msgf("Starting HTTP listener on %q...", config.MetricsPort) log.Info().Msgf("Starting HTTP listener on %q...", config.MetricsPort)
http.Handle("/metrics", promhttp.Handler()) http.Handle("/metrics", promhttp.Handler())
@ -85,9 +70,92 @@ func runMain(ctx context.Context) error {
return fmt.Errorf("HTTP server shutdown failed: %w", err) return fmt.Errorf("HTTP server shutdown failed: %w", err)
} }
} }
log.Info().Msgf("Waiting for consumers to stop...")
<-consumersCh
return <-errCh return <-errCh
} }
func runConsumers(ctx context.Context, db *gorm.DB, doneCh chan struct{}) {
log := zerolog.Ctx(ctx)
defer close(doneCh)
type consumerHandle struct {
cancel context.CancelFunc
consumer *Consumer
}
running := map[string]consumerHandle{}
ticker := time.NewTicker(time.Minute)
defer ticker.Stop()
t := make(chan time.Time, 1)
t <- time.Now()
for {
select {
case <-t:
remotes := []pds.PDS{}
if err := db.Find(&remotes).Error; err != nil {
log.Error().Err(err).Msgf("Failed to get a list of known PDSs: %s", err)
break
}
shouldBeRunning := map[string]pds.PDS{}
for _, remote := range remotes {
if remote.Disabled {
continue
}
shouldBeRunning[remote.Host] = remote
}
for host, handle := range running {
if _, found := shouldBeRunning[host]; found {
continue
}
handle.cancel()
_ = handle.consumer.Wait(ctx)
delete(running, host)
}
for host, remote := range shouldBeRunning {
if _, found := running[host]; found {
continue
}
subCtx, cancel := context.WithCancel(ctx)
c, err := NewConsumer(subCtx, &remote, db)
if err != nil {
log.Error().Err(err).Msgf("Failed to create a consumer for %q: %s", remote.Host, err)
continue
}
if err := c.Start(ctx); err != nil {
log.Error().Err(err).Msgf("Failed ot start a consumer for %q: %s", remote.Host, err)
continue
}
running[host] = consumerHandle{
cancel: cancel,
consumer: c,
}
}
case <-ctx.Done():
for host, handle := range running {
handle.cancel()
_ = handle.consumer.Wait(ctx)
delete(running, host)
}
case v := <-ticker.C:
// Non-blocking send.
select {
case t <- v:
default:
}
}
}
}
func main() { func main() {
flag.StringVar(&config.LogFile, "log", "", "Path to the log file. If empty, will log to stderr") flag.StringVar(&config.LogFile, "log", "", "Path to the log file. If empty, will log to stderr")
flag.StringVar(&config.LogFormat, "log-format", "text", "Logging format. 'text' or 'json'") flag.StringVar(&config.LogFormat, "log-format", "text", "Logging format. 'text' or 'json'")