SQLite cache

pull/12/head
Philipp Heckel 2021-11-02 21:09:49 -04:00
parent 1c7695c1f3
commit 7b810acfb5
5 changed files with 254 additions and 133 deletions

View File

@ -1,61 +1,14 @@
package server package server
import ( import (
"database/sql"
"time"
_ "github.com/mattn/go-sqlite3" // SQLite driver _ "github.com/mattn/go-sqlite3" // SQLite driver
"time"
) )
const ( type cache interface {
createTableQuery = `CREATE TABLE IF NOT EXISTS messages ( AddMessage(m *message) error
id VARCHAR(20) PRIMARY KEY, Messages(topic string, since time.Time) ([]*message, error)
time INT NOT NULL, MessageCount(topic string) (int, error)
topic VARCHAR(64) NOT NULL, Topics() (map[string]*topic, error)
message VARCHAR(1024) NOT NULL Prune(keep time.Duration) error
)`
insertQuery = `INSERT INTO messages (id, time, topic, message) VALUES (?, ?, ?, ?)`
pruneOlderThanQuery = `DELETE FROM messages WHERE time < ?`
)
type cache struct {
db *sql.DB
insert *sql.Stmt
prune *sql.Stmt
}
func newCache(filename string) (*cache, error) {
db, err := sql.Open("sqlite3", filename)
if err != nil {
return nil, err
}
if _, err := db.Exec(createTableQuery); err != nil {
return nil, err
}
insert, err := db.Prepare(insertQuery)
if err != nil {
return nil, err
}
prune, err := db.Prepare(pruneOlderThanQuery)
if err != nil {
return nil, err
}
return &cache{
db: db,
insert: insert,
prune: prune,
}, nil
}
func (c *cache) Load() (map[string]*topic, error) {
}
func (c *cache) Add(m *message) error {
_, err := c.insert.Exec(m.ID, m.Time, m.Topic, m.Message)
return err
}
func (c *cache) Prune(olderThan time.Duration) error {
_, err := c.prune.Exec(time.Now().Add(-1 * olderThan).Unix())
return err
} }

View File

@ -0,0 +1,80 @@
package server
import (
_ "github.com/mattn/go-sqlite3" // SQLite driver
"sync"
"time"
)
type memCache struct {
messages map[string][]*message
mu sync.Mutex
}
var _ cache = (*memCache)(nil)
func newMemCache() *memCache {
return &memCache{
messages: make(map[string][]*message),
}
}
func (s *memCache) AddMessage(m *message) error {
s.mu.Lock()
defer s.mu.Unlock()
if _, ok := s.messages[m.Topic]; !ok {
s.messages[m.Topic] = make([]*message, 0)
}
s.messages[m.Topic] = append(s.messages[m.Topic], m)
return nil
}
func (s *memCache) Messages(topic string, since time.Time) ([]*message, error) {
s.mu.Lock()
defer s.mu.Unlock()
if _, ok := s.messages[topic]; !ok {
return make([]*message, 0), nil
}
messages := make([]*message, 0) // copy!
for _, m := range s.messages[topic] {
msgTime := time.Unix(m.Time, 0)
if msgTime == since || msgTime.After(since) {
messages = append(messages, m)
}
}
return messages, nil
}
func (s *memCache) MessageCount(topic string) (int, error) {
s.mu.Lock()
defer s.mu.Unlock()
if _, ok := s.messages[topic]; !ok {
return 0, nil
}
return len(s.messages[topic]), nil
}
func (s *memCache) Topics() (map[string]*topic, error) {
// Hack since we know when this is called there are no messages!
return make(map[string]*topic), nil
}
func (s *memCache) Prune(keep time.Duration) error {
s.mu.Lock()
defer s.mu.Unlock()
for topic, _ := range s.messages {
s.pruneTopic(topic, keep)
}
return nil
}
func (s *memCache) pruneTopic(topic string, keep time.Duration) {
for i, m := range s.messages[topic] {
msgTime := time.Unix(m.Time, 0)
if time.Since(msgTime) < keep {
s.messages[topic] = s.messages[topic][i:]
return
}
}
s.messages[topic] = make([]*message, 0) // all messages expired
}

View File

@ -0,0 +1,127 @@
package server
import (
"database/sql"
"errors"
_ "github.com/mattn/go-sqlite3" // SQLite driver
"time"
)
const (
createTableQuery = `
BEGIN;
CREATE TABLE IF NOT EXISTS messages (
id VARCHAR(20) PRIMARY KEY,
time INT NOT NULL,
topic VARCHAR(64) NOT NULL,
message VARCHAR(1024) NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
COMMIT;
`
insertMessageQuery = `INSERT INTO messages (id, time, topic, message) VALUES (?, ?, ?, ?)`
pruneMessagesQuery = `DELETE FROM messages WHERE time < ?`
selectMessagesSinceTimeQuery = `
SELECT id, time, message
FROM messages
WHERE topic = ? AND time >= ?
ORDER BY time ASC
`
selectMessageCountQuery = `SELECT count(*) FROM messages WHERE topic = ?`
selectTopicsQuery = `SELECT topic, MAX(time) FROM messages GROUP BY TOPIC`
)
type sqliteCache struct {
db *sql.DB
}
var _ cache = (*sqliteCache)(nil)
func newSqliteCache(filename string) (*sqliteCache, error) {
db, err := sql.Open("sqlite3", filename)
if err != nil {
return nil, err
}
if _, err := db.Exec(createTableQuery); err != nil {
return nil, err
}
return &sqliteCache{
db: db,
}, nil
}
func (c *sqliteCache) AddMessage(m *message) error {
_, err := c.db.Exec(insertMessageQuery, m.ID, m.Time, m.Topic, m.Message)
return err
}
func (c *sqliteCache) Messages(topic string, since time.Time) ([]*message, error) {
rows, err := c.db.Query(selectMessagesSinceTimeQuery, topic, since.Unix())
if err != nil {
return nil, err
}
defer rows.Close()
messages := make([]*message, 0)
for rows.Next() {
var timestamp int64
var id, msg string
if err := rows.Scan(&id, &timestamp, &msg); err != nil {
return nil, err
}
messages = append(messages, &message{
ID: id,
Time: timestamp,
Event: messageEvent,
Topic: topic,
Message: msg,
})
}
if err := rows.Err(); err != nil {
return nil, err
}
return messages, nil
}
func (c *sqliteCache) MessageCount(topic string) (int, error) {
rows, err := c.db.Query(selectMessageCountQuery, topic)
if err != nil {
return 0, err
}
defer rows.Close()
var count int
if !rows.Next() {
return 0, errors.New("no rows found")
}
if err := rows.Scan(&count); err != nil {
return 0, err
} else if err := rows.Err(); err != nil {
return 0, err
}
return count, nil
}
func (s *sqliteCache) Topics() (map[string]*topic, error) {
rows, err := s.db.Query(selectTopicsQuery)
if err != nil {
return nil, err
}
defer rows.Close()
topics := make(map[string]*topic, 0)
for rows.Next() {
var id string
var last int64
if err := rows.Scan(&id, &last); err != nil {
return nil, err
}
topics[id] = newTopic(id, time.Unix(last, 0))
}
if err := rows.Err(); err != nil {
return nil, err
}
return topics, nil
}
func (c *sqliteCache) Prune(keep time.Duration) error {
_, err := c.db.Exec(pruneMessagesQuery, time.Now().Add(-1 * keep).Unix())
return err
}

View File

@ -32,7 +32,7 @@ type Server struct {
visitors map[string]*visitor visitors map[string]*visitor
firebase subscriber firebase subscriber
messages int64 messages int64
cache *cache cache cache
mu sync.Mutex mu sync.Mutex
} }
@ -78,30 +78,28 @@ func New(conf *config.Config) (*Server, error) {
return nil, err return nil, err
} }
} }
cache, err := maybeCreateCache(conf) cache, err := createCache(conf)
if err != nil { if err != nil {
return nil, err return nil, err
} }
topics := make(map[string]*topic) topics, err := cache.Topics()
if cache != nil { if err != nil {
if topics, err = cache.Load(); err != nil { return nil, err
return nil, err
}
} }
return &Server{ return &Server{
config: conf, config: conf,
cache: cache, cache: cache,
firebase: firebaseSubscriber, firebase: firebaseSubscriber,
topics: topics, topics: topics,
visitors: make(map[string]*visitor), visitors: make(map[string]*visitor),
}, nil }, nil
} }
func maybeCreateCache(conf *config.Config) (*cache, error) { func createCache(conf *config.Config) (cache, error) {
if conf.CacheFile == "" { if conf.CacheFile != "" {
return nil, nil return newSqliteCache(conf.CacheFile)
} }
return newCache(conf.CacheFile) return newMemCache(), nil
} }
func createFirebaseSubscriber(conf *config.Config) (subscriber, error) { func createFirebaseSubscriber(conf *config.Config) (subscriber, error) {
@ -202,8 +200,8 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito
if err := t.Publish(m); err != nil { if err := t.Publish(m); err != nil {
return err return err
} }
if s.cache != nil { if err := s.cache.AddMessage(m); err != nil {
s.cache.Add(m) return err
} }
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
s.mu.Lock() s.mu.Lock()
@ -277,20 +275,18 @@ func (s *Server) handleSubscribe(w http.ResponseWriter, r *http.Request, v *visi
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
w.Header().Set("Content-Type", contentType) w.Header().Set("Content-Type", contentType)
if poll { if poll {
return sendOldMessages(t, since, sub) return s.sendOldMessages(t, since, sub)
} }
subscriberID := t.Subscribe(sub) subscriberID := t.Subscribe(sub)
defer t.Unsubscribe(subscriberID) defer t.Unsubscribe(subscriberID)
if err := sub(newOpenMessage(t.id)); err != nil { // Send out open message if err := sub(newOpenMessage(t.id)); err != nil { // Send out open message
return err return err
} }
if err := sendOldMessages(t, since, sub); err != nil { if err := s.sendOldMessages(t, since, sub); err != nil {
return err return err
} }
for { for {
select { select {
case <-t.ctx.Done():
return nil
case <-r.Context().Done(): case <-r.Context().Done():
return nil return nil
case <-time.After(s.config.KeepaliveInterval): case <-time.After(s.config.KeepaliveInterval):
@ -302,11 +298,15 @@ func (s *Server) handleSubscribe(w http.ResponseWriter, r *http.Request, v *visi
} }
} }
func sendOldMessages(t *topic, since time.Time, sub subscriber) error { func (s *Server) sendOldMessages(t *topic, since time.Time, sub subscriber) error {
if since.IsZero() { if since.IsZero() {
return nil return nil
} }
for _, m := range t.Messages(since) { messages, err := s.cache.Messages(t.id, since)
if err != nil {
return err
}
for _, m := range messages {
if err := sub(m); err != nil { if err := sub(m); err != nil {
return err return err
} }
@ -340,7 +340,7 @@ func (s *Server) topic(id string) (*topic, error) {
if len(s.topics) >= s.config.GlobalTopicLimit { if len(s.topics) >= s.config.GlobalTopicLimit {
return nil, errHTTPTooManyRequests return nil, errHTTPTooManyRequests
} }
s.topics[id] = newTopic(id) s.topics[id] = newTopic(id, time.Now())
if s.firebase != nil { if s.firebase != nil {
s.topics[id].Subscribe(s.firebase) s.topics[id].Subscribe(s.firebase)
} }
@ -360,28 +360,28 @@ func (s *Server) updateStatsAndExpire() {
} }
// Prune cache // Prune cache
if s.cache != nil { if err := s.cache.Prune(s.config.MessageBufferDuration); err != nil {
if err := s.cache.Prune(s.config.MessageBufferDuration); err != nil { log.Printf("error pruning cache: %s", err.Error())
log.Printf("error pruning cache: %s", err.Error())
}
} }
// Prune old messages, remove subscriptions without subscribers // Prune old messages, remove subscriptions without subscribers
for _, t := range s.topics {
t.Prune(s.config.MessageBufferDuration)
subs, msgs := t.Stats()
if msgs == 0 && (subs == 0 || (s.firebase != nil && subs == 1)) { // Firebase is a subscriber!
delete(s.topics, t.id)
}
}
// Print stats
var subscribers, messages int var subscribers, messages int
for _, t := range s.topics { for _, t := range s.topics {
subs, msgs := t.Stats() subs := t.Subscribers()
msgs, err := s.cache.MessageCount(t.id)
if err != nil {
log.Printf("cannot get stats for topic %s: %s", t.id, err.Error())
continue
}
if msgs == 0 && (subs == 0 || (s.firebase != nil && subs == 1)) { // Firebase is a subscriber!
delete(s.topics, t.id)
continue
}
subscribers += subs subscribers += subs
messages += msgs messages += msgs
} }
// Print stats
log.Printf("Stats: %d message(s) published, %d topic(s) active, %d subscriber(s), %d message(s) buffered, %d visitor(s)", log.Printf("Stats: %d message(s) published, %d topic(s) active, %d subscriber(s), %d message(s) buffered, %d visitor(s)",
s.messages, len(s.topics), subscribers, messages, len(s.visitors)) s.messages, len(s.topics), subscribers, messages, len(s.visitors))
} }

View File

@ -1,7 +1,6 @@
package server package server
import ( import (
"context"
"log" "log"
"math/rand" "math/rand"
"sync" "sync"
@ -12,11 +11,8 @@ import (
// can publish a message // can publish a message
type topic struct { type topic struct {
id string id string
subscribers map[int]subscriber
messages []*message
last time.Time last time.Time
ctx context.Context subscribers map[int]subscriber
cancel context.CancelFunc
mu sync.Mutex mu sync.Mutex
} }
@ -24,15 +20,11 @@ type topic struct {
type subscriber func(msg *message) error type subscriber func(msg *message) error
// newTopic creates a new topic // newTopic creates a new topic
func newTopic(id string) *topic { func newTopic(id string, last time.Time) *topic {
ctx, cancel := context.WithCancel(context.Background())
return &topic{ return &topic{
id: id, id: id,
last: last,
subscribers: make(map[int]subscriber), subscribers: make(map[int]subscriber),
messages: make([]*message, 0),
last: time.Now(),
ctx: ctx,
cancel: cancel,
} }
} }
@ -55,7 +47,6 @@ func (t *topic) Publish(m *message) error {
t.mu.Lock() t.mu.Lock()
defer t.mu.Unlock() defer t.mu.Unlock()
t.last = time.Now() t.last = time.Now()
t.messages = append(t.messages, m)
for _, s := range t.subscribers { for _, s := range t.subscribers {
if err := s(m); err != nil { if err := s(m); err != nil {
log.Printf("error publishing message to subscriber") log.Printf("error publishing message to subscriber")
@ -64,38 +55,8 @@ func (t *topic) Publish(m *message) error {
return nil return nil
} }
func (t *topic) Messages(since time.Time) []*message { func (t *topic) Subscribers() int {
t.mu.Lock() t.mu.Lock()
defer t.mu.Unlock() defer t.mu.Unlock()
messages := make([]*message, 0) // copy! return len(t.subscribers)
for _, m := range t.messages {
msgTime := time.Unix(m.Time, 0)
if msgTime == since || msgTime.After(since) {
messages = append(messages, m)
}
}
return messages
}
func (t *topic) Prune(keep time.Duration) {
t.mu.Lock()
defer t.mu.Unlock()
for i, m := range t.messages {
msgTime := time.Unix(m.Time, 0)
if time.Since(msgTime) < keep {
t.messages = t.messages[i:]
return
}
}
t.messages = make([]*message, 0)
}
func (t *topic) Stats() (subscribers int, messages int) {
t.mu.Lock()
defer t.mu.Unlock()
return len(t.subscribers), len(t.messages)
}
func (t *topic) Close() {
t.cancel()
} }