code refactor

This commit is contained in:
Astra 2026-03-23 07:59:53 +00:00
parent 20048736eb
commit cf6e27539b
3 changed files with 423 additions and 245 deletions

584
bot.go
View file

@ -11,6 +11,7 @@ import (
"os"
"regexp"
"strings"
"sync"
"github.com/gotd/td/session"
"github.com/gotd/td/telegram"
@ -39,6 +40,7 @@ type Config struct {
NtfyTopic string `yaml:"ntfy_topic"`
SessionPath string `yaml:"session_path"`
YAMLPatterns []YAMLScamPattern `yaml:"scam_patterns"`
BadWords []string `yaml:"bad_words"`
Patterns []scamPattern `yaml:"-"` // Compiled patterns, not from YAML
}
@ -56,6 +58,7 @@ ntfy_token: ""
ntfy_topic: ""
session_path: "antiscam.session"
scam_patterns: []
bad_words: []
`
if err := os.WriteFile(path, []byte(defaultConfig), 0644); err != nil {
return nil, fmt.Errorf("creating default config file: %w", err)
@ -158,6 +161,148 @@ var scamPatterns = []scamPattern{
},
}
// BotState holds shared runtime state that may be updated at runtime.
type BotState struct {
cfg *Config
configPath string
selfID int64
badWordRe *regexp.Regexp
mu sync.RWMutex
}
func newBotState(cfg *Config, configPath string, selfID int64) *BotState {
s := &BotState{cfg: cfg, configPath: configPath, selfID: selfID}
s.compileBadWords()
return s
}
// compileBadWords rebuilds the bad word regex from cfg.BadWords.
// Must be called with mu held (write).
func (s *BotState) compileBadWords() {
if len(s.cfg.BadWords) == 0 {
s.badWordRe = nil
return
}
escaped := make([]string, len(s.cfg.BadWords))
for i, w := range s.cfg.BadWords {
escaped[i] = regexp.QuoteMeta(w)
}
s.badWordRe = regexp.MustCompile(`(?i)(` + strings.Join(escaped, "|") + `)`)
log.Printf("✓ Bad word pattern: %s", s.badWordRe)
}
func (s *BotState) saveConfig() error {
data, err := yaml.Marshal(s.cfg)
if err != nil {
return fmt.Errorf("marshaling config: %w", err)
}
return os.WriteFile(s.configPath, data, 0644)
}
func sendAdminReply(ctx context.Context, api *tg.Client, alertPeer *tg.InputPeerChannel, text string) {
sender := message.NewSender(api)
_, err := sender.To(alertPeer).Text(ctx, text)
if err != nil {
log.Printf("Failed to send admin reply: %v", err)
}
}
func handleOwnerCommand(ctx context.Context, api *tg.Client, state *BotState, alertPeer *tg.InputPeerChannel, text string) {
text = strings.TrimSpace(text)
if !strings.HasPrefix(text, "/badword") {
return
}
parts := strings.Fields(text)
if len(parts) < 2 {
sendAdminReply(ctx, api, alertPeer, "Usage: /badword <list|add|rem> [word]")
return
}
subCmd := strings.ToLower(parts[1])
state.mu.Lock()
defer state.mu.Unlock()
switch subCmd {
case "list":
if len(state.cfg.BadWords) == 0 {
sendAdminReply(ctx, api, alertPeer, "No bad words configured.")
return
}
var sb strings.Builder
sb.WriteString("Bad words:\n")
for _, w := range state.cfg.BadWords {
sb.WriteString(" • ")
sb.WriteString(w)
sb.WriteString("\n")
}
reply := sb.String()
if state.badWordRe != nil {
reply += "\nPattern: " + state.badWordRe.String()
}
sendAdminReply(ctx, api, alertPeer, strings.TrimRight(reply, "\n"))
case "add":
if len(parts) < 3 {
sendAdminReply(ctx, api, alertPeer, "Usage: /badword add <word>")
return
}
word := parts[2]
for _, w := range state.cfg.BadWords {
if strings.EqualFold(w, word) {
sendAdminReply(ctx, api, alertPeer, fmt.Sprintf("Word %q already exists.", word))
return
}
}
state.cfg.BadWords = append(state.cfg.BadWords, word)
state.compileBadWords()
if err := state.saveConfig(); err != nil {
log.Printf("Failed to save config: %v", err)
sendAdminReply(ctx, api, alertPeer, fmt.Sprintf("Added %q but failed to save config: %v", word, err))
return
}
reply := fmt.Sprintf("Added bad word: %q", word)
if state.badWordRe != nil {
reply += "\nPattern: " + state.badWordRe.String()
}
sendAdminReply(ctx, api, alertPeer, reply)
case "rem":
if len(parts) < 3 {
sendAdminReply(ctx, api, alertPeer, "Usage: /badword rem <word>")
return
}
word := parts[2]
newWords := make([]string, 0, len(state.cfg.BadWords))
found := false
for _, w := range state.cfg.BadWords {
if strings.EqualFold(w, word) {
found = true
continue
}
newWords = append(newWords, w)
}
if !found {
sendAdminReply(ctx, api, alertPeer, fmt.Sprintf("Word %q not found.", word))
return
}
state.cfg.BadWords = newWords
state.compileBadWords()
if err := state.saveConfig(); err != nil {
log.Printf("Failed to save config: %v", err)
}
patternStr := "none"
if state.badWordRe != nil {
patternStr = state.badWordRe.String()
}
sendAdminReply(ctx, api, alertPeer, fmt.Sprintf("Removed bad word: %q\nPattern: %s", word, patternStr))
default:
sendAdminReply(ctx, api, alertPeer, "Unknown subcommand. Usage: /badword <list|add|rem> [word]")
}
}
func keywordMatchScore(input string, patterns []scamPattern) matchResult {
matched := make(map[string][]string)
totalScore := 0.0
@ -195,33 +340,23 @@ func escapeMarkdown(text string) string {
return result.String()
}
func escapeHTML(text string) string {
var result strings.Builder
for _, r := range text {
switch r {
case '&':
result.WriteString("&amp;")
case '<':
result.WriteString("&lt;")
case '>':
result.WriteString("&gt;")
case '"':
result.WriteString("&quot;")
case '\'':
result.WriteString("&#39;")
default:
result.WriteRune(r)
}
}
return result.String()
func fullName(first, last string) string {
return strings.TrimSpace(first + " " + last)
}
func notify(message, host, topic, title string, priority int, token string) {
if !strings.HasSuffix(host, "/") {
host = host + "/"
func formatUsername(username string) string {
if username == "" {
return ""
}
url := host + topic
req, err := http.NewRequest("POST", url, bytes.NewBufferString(message))
return " (@" + username + ")"
}
func notify(body string, cfg *Config, title string) {
host := cfg.NtfyHost
if !strings.HasSuffix(host, "/") {
host += "/"
}
req, err := http.NewRequest("POST", host+cfg.NtfyTopic, bytes.NewBufferString(body))
if err != nil {
log.Printf("notify: creating request: %v", err)
return
@ -229,11 +364,10 @@ func notify(message, host, topic, title string, priority int, token string) {
if title != "" {
req.Header.Set("Title", title)
}
req.Header.Set("Priority", fmt.Sprintf("%d", priority))
req.Header.Set("Priority", "5")
req.Header.Set("Markdown", "yes")
if token != "" {
// ntfy uses basic auth with empty username and token as password
encoded := base64.StdEncoding.EncodeToString([]byte(":" + token))
if cfg.NtfyToken != "" {
encoded := base64.StdEncoding.EncodeToString([]byte(":" + cfg.NtfyToken))
req.Header.Set("Authorization", "Basic "+encoded)
}
resp, err := http.DefaultClient.Do(req)
@ -244,94 +378,162 @@ func notify(message, host, topic, title string, priority int, token string) {
defer resp.Body.Close()
}
func resolveAlertPeer(ctx context.Context, api *tg.Client, cfg *Config) (*tg.InputPeerChannel, error) {
result, err := api.MessagesGetDialogs(ctx, &tg.MessagesGetDialogsRequest{
func resolveChannel(ctx context.Context, api *tg.Client, channelID int64) (*tg.InputPeerChannel, *tg.Channel, error) {
req := &tg.MessagesGetDialogsRequest{
OffsetPeer: &tg.InputPeerEmpty{},
Limit: 100,
})
if err != nil {
return nil, err
}
var chats []tg.ChatClass
for {
result, err := api.MessagesGetDialogs(ctx, req)
if err != nil {
return nil, nil, err
}
switch v := result.(type) {
case *tg.MessagesDialogs:
chats = v.Chats
case *tg.MessagesDialogsSlice:
chats = v.Chats
case *tg.MessagesDialogsNotModified:
return nil, fmt.Errorf("dialogs not modified")
default:
return nil, fmt.Errorf("unexpected dialogs type: %T", result)
}
var (
chats []tg.ChatClass
dialogs []tg.DialogClass
messages []tg.MessageClass
done bool
)
for _, chat := range chats {
if channel, ok := chat.(*tg.Channel); ok {
if channel.ID == cfg.AlertChat {
return &tg.InputPeerChannel{
ChannelID: channel.ID,
AccessHash: channel.AccessHash,
}, nil
switch v := result.(type) {
case *tg.MessagesDialogs:
chats, dialogs, messages = v.Chats, v.Dialogs, v.Messages
done = true
case *tg.MessagesDialogsSlice:
chats, dialogs, messages = v.Chats, v.Dialogs, v.Messages
case *tg.MessagesDialogsNotModified:
return nil, nil, fmt.Errorf("dialogs not modified")
default:
return nil, nil, fmt.Errorf("unexpected dialogs type: %T", result)
}
for _, chat := range chats {
if channel, ok := chat.(*tg.Channel); ok {
if channel.ID == channelID {
return &tg.InputPeerChannel{
ChannelID: channel.ID,
AccessHash: channel.AccessHash,
}, channel, nil
}
}
}
if done || len(dialogs) == 0 {
break
}
// Build offset from the last dialog for the next page.
last, ok := dialogs[len(dialogs)-1].(*tg.Dialog)
if !ok {
break
}
msgID := last.TopMessage
var msgDate int
for _, m := range messages {
if msg, ok := m.(*tg.Message); ok && msg.ID == msgID {
msgDate = msg.Date
break
}
}
if msgDate == 0 {
break
}
offsetPeer, err := toPeerInput(last.Peer, chats)
if err != nil {
break
}
req.OffsetDate = msgDate
req.OffsetID = msgID
req.OffsetPeer = offsetPeer
}
return nil, fmt.Errorf("alert channel %d not found", cfg.AlertChat)
return nil, nil, fmt.Errorf("channel %d not found in dialogs", channelID)
}
func handleNewMessage(ctx context.Context, api *tg.Client, alertPeer *tg.InputPeerChannel, cfg *Config, entities tg.Entities, update any) error {
// Extract message from either UpdateNewMessage or UpdateNewChannelMessage
var msg *tg.Message
func toPeerInput(peer tg.PeerClass, chats []tg.ChatClass) (tg.InputPeerClass, error) {
switch p := peer.(type) {
case *tg.PeerChannel:
for _, c := range chats {
if ch, ok := c.(*tg.Channel); ok && ch.ID == p.ChannelID {
return &tg.InputPeerChannel{ChannelID: ch.ID, AccessHash: ch.AccessHash}, nil
}
}
return nil, fmt.Errorf("channel %d not found in chats", p.ChannelID)
case *tg.PeerChat:
return &tg.InputPeerChat{ChatID: p.ChatID}, nil
case *tg.PeerUser:
return &tg.InputPeerUser{UserID: p.UserID}, nil
default:
return nil, fmt.Errorf("unknown peer type %T", peer)
}
}
func extractMessage(update any) *tg.Message {
switch u := update.(type) {
case *tg.UpdateNewMessage:
m, ok := u.Message.(*tg.Message)
if !ok {
return nil
}
msg = m
m, _ := u.Message.(*tg.Message)
return m
case *tg.UpdateNewChannelMessage:
m, ok := u.Message.(*tg.Message)
if !ok {
return nil
}
msg = m
m, _ := u.Message.(*tg.Message)
return m
default:
log.Printf("DEBUG: unknown update type %T", update)
return nil
}
}
// Only process messages sent by the bot (Out=true)
func handleNewMessage(ctx context.Context, api *tg.Client, alertPeer *tg.InputPeerChannel, state *BotState, entities tg.Entities, update any) error {
msg := extractMessage(update)
if msg == nil {
return nil
}
// Handle commands from the bot owner (messages sent by the connected account).
if msg.Out {
handleOwnerCommand(ctx, api, state, alertPeer, msg.Message)
return nil
}
// Check if from monitored supergroup (channels in gotd terminology)
peerChannel, ok := msg.PeerID.(*tg.PeerChannel)
if !ok || int64(peerChannel.ChannelID) != cfg.MonitoredChat {
if !ok || int64(peerChannel.ChannelID) != state.cfg.MonitoredChat {
return nil
}
log.Printf("✓ Processing message in chat %v: %s", msg.PeerID, msg.Message)
chatID := int64(peerChannel.ChannelID)
// Use YAML patterns if available, otherwise use defaults
patterns := cfg.Patterns
if len(patterns) == 0 {
patterns = scamPatterns
log.Printf("Using hardcoded scam patterns (%d patterns)", len(patterns))
// Check bad words first — any match sets score to 1.0 immediately.
state.mu.RLock()
badWordRe := state.badWordRe
state.mu.RUnlock()
var result matchResult
if badWordRe != nil && badWordRe.MatchString(msg.Message) {
result = matchResult{
score: 1.0,
matchedKeywords: map[string][]string{"bad_words": badWordRe.FindAllString(msg.Message, -1)},
}
log.Printf("Bad word match in message, score set to 1.0")
} else {
log.Printf("Using YAML-loaded scam patterns (%d patterns)", len(patterns))
}
// Use YAML patterns if available, otherwise use defaults.
patterns := state.cfg.Patterns
if len(patterns) == 0 {
patterns = scamPatterns
log.Printf("Using hardcoded scam patterns (%d patterns)", len(patterns))
} else {
log.Printf("Using YAML-loaded scam patterns (%d patterns)", len(patterns))
}
// Score the message
result := keywordMatchScore(msg.Message, patterns)
result = keywordMatchScore(msg.Message, patterns)
if extended_latin, ok := result.matchedKeywords["extended_latin"]; ok {
for range extended_latin {
result.score += 0.1
if result.score >= 1.0 {
result.score = 1.0
break
if extended_latin, ok := result.matchedKeywords["extended_latin"]; ok {
for range extended_latin {
result.score += 0.1
if result.score >= 1.0 {
result.score = 1.0
break
}
}
}
}
@ -340,113 +542,90 @@ func handleNewMessage(ctx context.Context, api *tg.Client, alertPeer *tg.InputPe
return nil
}
if result.score == 1.0 ||
(result.score == 1.0 &&
result.matchedKeywords["extended_latin"] != nil &&
len(result.matchedKeywords["links"]) > 0) {
log.Printf("Matched message with score %.2f", result.score)
for key, values := range result.matchedKeywords {
for _, value := range values {
log.Printf(" %s: %s", key, value)
}
log.Printf("Matched message with score %.2f", result.score)
for key, values := range result.matchedKeywords {
for _, value := range values {
log.Printf(" %s: %s", key, value)
}
}
// Delete the message from supergroup
channel, ok := entities.Channels[chatID]
if !ok {
return fmt.Errorf("channel %d not found in entities", chatID)
}
_, err := api.ChannelsDeleteMessages(ctx, &tg.ChannelsDeleteMessagesRequest{
Channel: &tg.InputChannel{
ChannelID: chatID,
AccessHash: channel.AccessHash,
},
ID: []int{msg.ID},
})
if err != nil {
log.Printf("Failed to delete message: %v", err)
}
// Delete the message from supergroup
channel, ok := entities.Channels[chatID]
if !ok {
return fmt.Errorf("channel %d not found in entities", chatID)
}
_, err := api.ChannelsDeleteMessages(ctx, &tg.ChannelsDeleteMessagesRequest{
Channel: &tg.InputChannel{
ChannelID: chatID,
AccessHash: channel.AccessHash,
},
ID: []int{msg.ID},
})
if err != nil {
log.Printf("Failed to delete message: %v", err)
}
// Get sender info
var senderID int64
if fromID, ok := msg.FromID.(*tg.PeerUser); ok {
senderID = int64(fromID.UserID)
} else {
return fmt.Errorf("could not determine sender")
}
// Get sender info
fromID, ok := msg.FromID.(*tg.PeerUser)
if !ok {
return fmt.Errorf("could not determine sender")
}
senderID := int64(fromID.UserID)
user, ok := entities.Users[senderID]
if !ok {
return fmt.Errorf("user %d not found in entities", senderID)
}
user, ok := entities.Users[senderID]
if !ok {
return fmt.Errorf("user %d not found in entities", senderID)
}
displayName := user.FirstName
if user.LastName != "" {
displayName += " " + user.LastName
}
displayName = strings.TrimSpace(displayName)
displayName := fullName(user.FirstName, user.LastName)
userDisplay := displayName
if user.Username != "" {
userDisplay += " (@" + user.Username + ")"
} else {
userDisplay += " (no username)"
}
username := "no username"
if user.Username != "" {
username = "@" + user.Username
}
// Restrict sender in supergroup
_, err = api.ChannelsEditBanned(ctx, &tg.ChannelsEditBannedRequest{
Channel: &tg.InputChannel{
ChannelID: chatID,
AccessHash: channel.AccessHash,
},
Participant: &tg.InputPeerUser{
UserID: senderID,
AccessHash: user.AccessHash,
},
BannedRights: tg.ChatBannedRights{
SendMessages: true,
SendMedia: true,
SendStickers: true,
SendGifs: true,
UntilDate: 0,
},
})
if err != nil {
log.Printf("Failed to restrict user %d: %v", senderID, err)
}
// Restrict sender in supergroup
_, err = api.ChannelsEditBanned(ctx, &tg.ChannelsEditBannedRequest{
Channel: &tg.InputChannel{
ChannelID: chatID,
AccessHash: channel.AccessHash,
},
Participant: &tg.InputPeerUser{
UserID: senderID,
AccessHash: user.AccessHash,
},
BannedRights: tg.ChatBannedRights{
SendMessages: true,
SendMedia: true,
SendStickers: true,
SendGifs: true,
UntilDate: 0,
},
})
if err != nil {
log.Printf("Failed to restrict user %d: %v", senderID, err)
}
chatName := channel.Title
chatDisplay := escapeMarkdown(chatName) + escapeMarkdown(formatUsername(channel.Username))
// Get supergroup name and username
chatName := channel.Title
chatDisplay := escapeMarkdown(chatName)
if channel.Username != "" {
chatDisplay += " (@" + escapeMarkdown(channel.Username) + ")"
}
matchMessageHTML := fmt.Sprintf("🚨 Matched\n<b>Score</b>: %.2f\n<b>Chat</b>: %s (ID: %d)\n<b>User</b>: %s (ID: %d)\n",
result.score, chatDisplay, chatID, userDisplay, senderID)
// Build alert message with HTML formatting for markdown v2
matchMessageHTML := fmt.Sprintf("🚨 Matched\n<b>Score</b>: %.2f\n<b>Chat</b>: %s (ID: %d)\n<b>User</b>: %s (ID: %d)\n",
result.score, chatDisplay, chatID, displayName+" ("+username+")", senderID)
if state.cfg.NtfyToken != "" || state.cfg.NtfyTopic != "" {
plainMessage := fmt.Sprintf("🚨 Matched\nScore: %.2f\nChat: %s (ID: %d)\nUser: %s (ID: %d)\n\n%s",
result.score, chatName+formatUsername(channel.Username), chatID, userDisplay, senderID, msg.Message)
notify(plainMessage, state.cfg, fmt.Sprintf("Scam Alert: %s", chatName))
}
// Send ntfy notification if config set (use plain text for ntfy)
if cfg.NtfyToken != "" || cfg.NtfyTopic != "" {
plainMessage := fmt.Sprintf("🚨 Matched\nScore: %.2f\nChat: %s (ID: %d)\nUser: %s (ID: %d)\n\n%s",
result.score, escapeMarkdown(chatDisplay), chatID, escapeMarkdown(displayName+" ("+username+")"), senderID, escapeMarkdown(msg.Message))
notify(plainMessage, cfg.NtfyHost, cfg.NtfyTopic, fmt.Sprintf("Scam Alert: %s", chatName), 5, cfg.NtfyToken)
}
// Create a resolver for user mentions (not needed for this message, but required by html.String)
userResolver := func(id int64) (tg.InputUserClass, error) {
return &tg.InputUserFromMessage{
Peer: alertPeer,
MsgID: 0,
UserID: id,
}, nil
}
// Send alert message to alert chat using StyledText with HTML
sender := message.NewSender(api)
_, err = sender.To(alertPeer).StyledText(ctx, html.String(userResolver, matchMessageHTML))
if err != nil {
log.Printf("Failed to send alert message: %v", err)
}
userResolver := func(id int64) (tg.InputUserClass, error) {
return &tg.InputUserFromMessage{Peer: alertPeer, MsgID: 0, UserID: id}, nil
}
sender := message.NewSender(api)
_, err = sender.To(alertPeer).StyledText(ctx, html.String(userResolver, matchMessageHTML))
if err != nil {
log.Printf("Failed to send alert message: %v", err)
}
return nil
@ -513,38 +692,37 @@ func main() {
api := client.API()
// Resolve alert channel peer at startup
alertPeer, err := resolveAlertPeer(ctx, api, cfg)
alertPeer, alertChannel, err := resolveChannel(ctx, api, cfg.AlertChat)
if err != nil {
return fmt.Errorf("resolving alert peer: %w", err)
return fmt.Errorf("resolving alert channel: %w", err)
}
_, monitoredChannel, err := resolveChannel(ctx, api, cfg.MonitoredChat)
if err != nil {
return fmt.Errorf("resolving monitored channel: %w", err)
}
// Register message handler for private/group messages
dispatcher.OnNewMessage(func(ctx context.Context, entities tg.Entities, update *tg.UpdateNewMessage) error {
return handleNewMessage(ctx, api, alertPeer, cfg, entities, update)
})
// Register handler for channel messages (supergroups)
dispatcher.OnNewChannelMessage(func(ctx context.Context, entities tg.Entities, update *tg.UpdateNewChannelMessage) error {
return handleNewMessage(ctx, api, alertPeer, cfg, entities, update)
})
// Get self ID for updates manager
// Get self ID for state and updates manager
status, err := client.Auth().Status(ctx)
if err != nil {
return err
}
state := newBotState(cfg, configPath, int64(status.User.ID))
// Register message handler for private/group messages
dispatcher.OnNewMessage(func(ctx context.Context, entities tg.Entities, update *tg.UpdateNewMessage) error {
return handleNewMessage(ctx, api, alertPeer, state, entities, update)
})
// Register handler for channel messages (supergroups)
dispatcher.OnNewChannelMessage(func(ctx context.Context, entities tg.Entities, update *tg.UpdateNewChannelMessage) error {
return handleNewMessage(ctx, api, alertPeer, state, entities, update)
})
user := status.User
username := ""
if user.Username != "" {
username = " (@" + user.Username + ")"
}
displayName := user.FirstName
if user.LastName != "" {
displayName += " " + user.LastName
}
log.Printf("✓ Logged in as: %s%s (ID: %d)", displayName, username, user.ID)
log.Printf("✓ Bot running, monitoring chat %d", cfg.MonitoredChat)
log.Printf("✓ Logged in as: %s%s (ID: %d)", fullName(user.FirstName, user.LastName), formatUsername(user.Username), user.ID)
log.Printf("✓ Monitoring chat: %s%s (ID: %d)", monitoredChannel.Title, formatUsername(monitoredChannel.Username), monitoredChannel.ID)
log.Printf("✓ Sending alerts to: %s%s (ID: %d)", alertChannel.Title, formatUsername(alertChannel.Username), alertChannel.ID)
return gapManager.Run(ctx, api, status.User.ID, updates.AuthOptions{})
})