1
0
mirror of https://github.com/TehPeGaSuS/GitBot.git synced 2026-06-25 04:25:45 +02:00
Files
2026-05-05 13:39:16 +02:00

409 lines
15 KiB
Python

"""
gitbot — main entry point.
Usage:
    python bot.py [-c gitbot.toml]          # normal run
    python bot.py --setup [-c gitbot.toml]  # create/reset owner account, then run
"""
import asyncio
import getpass
import itertools
import logging
import sys
try:
import tomllib
except ImportError:
try:
import tomli as tomllib
except ImportError:
print("Python < 3.11 detected. Install tomli: pip install tomli",
file=sys.stderr)
sys.exit(1)
import auth
import commands
import db
import irc_format as fmt
import rss as rss_module
import shlink as shlink_module
import webhook_github
import webhook_gitea
import webhook_gitlab
from irc_client import IRCClient
from webhook_server import WebhookServer
# Module-level logger named "bot"; main() sets up the logging level and format.
log = logging.getLogger("bot")
# Forge name -> webhook parser module. Bot._on_webhook looks the forge up here
# and silently ignores events for forges that are not listed.
PARSERS = {
"github": webhook_github,
"gitea": webhook_gitea,
"gitlab": webhook_gitlab,
}
# Event names used for a statically configured webhook that has no "events" key.
DEFAULT_EVENTS = {"ping", "code", "pr", "issue", "repo"}
def load_config(path: str) -> dict:
with open(path, "rb") as f:
return tomllib.load(f)
def run_setup(database):
    """Create (or reset) the owner account by prompting on the terminal.

    When an owner already exists the operator must answer "y" to continue;
    any other answer prints a cancellation notice and returns without
    changes. An empty nick terminates the process with exit status 1. The
    password is requested twice (without echo) and re-requested until it is
    non-empty and both entries match, after which the account is stored via
    ``auth.create_owner``.
    """
    print()

    if auth.has_owner(database):
        print("An owner account already exists.")
        confirmed = input("Reset it? [y/N] ").strip().lower() == "y"
        if not confirmed:
            print("Setup cancelled.")
            return

    print("── gitbot owner account setup ──────────────────")

    nick = input("Owner nick: ").strip()
    if not nick:
        print("Nick cannot be empty.")
        sys.exit(1)

    # Keep asking until a usable password has been entered twice.
    password = None
    while password is None:
        first_entry = getpass.getpass("Password: ")
        second_entry = getpass.getpass("Confirm password: ")
        if not first_entry:
            print("Password cannot be empty.")
        elif first_entry != second_entry:
            print("Passwords do not match, try again.")
        else:
            password = first_entry

    auth.create_owner(database, nick, password)
    print(f"Owner account created for '{nick}'.")
    print("You can now /msg the bot: identify <password>")
    print()
class Bot:
    """Glue between the IRC connections, the webhook server and the RSS poller.

    A Bot owns the database handle, one ``IRCClient`` per configured network
    (keyed by network name in ``_clients``), and an optional Shlink URL
    shortener. It routes incoming IRC messages to the ``commands`` module and
    turns forge webhook events into IRC lines for the subscribed channels.
    """

    def __init__(self, config: dict, config_path: str):
        """Open the database and register the statically configured hooks/feeds.

        Args:
            config: Parsed configuration (see ``load_config``).
            config_path: Path the configuration was read from; ``reload``
                re-reads this file.
        """
        self._cfg = config
        self._config_path = config_path
        self._database = db.connect(config.get("database", "gitbot.db"))
        self._clients: dict[str, IRCClient] = {}
        # Strong references to background tasks. The event loop itself only
        # keeps weak references, so an unreferenced task may be collected
        # before it finishes.
        self._tasks: set[asyncio.Task] = set()
        self._shlink = shlink_module.from_config(config.get("shlink", {}))
        if self._shlink:
            log.info("Shlink URL shortener enabled")
        self._load_static_webhooks()
        self._load_static_rss()

    @property
    def _commit_limit(self) -> int:
        """Value of the ``commit_limit`` config key (3 when absent)."""
        return self._cfg.get("commit_limit", 3)

    # ── Background tasks ──────────────────────────────────────────────────────

    def _spawn(self, coro) -> asyncio.Task:
        """Schedule *coro* as a task and keep it referenced until it is done.

        Must be called while an event loop is running.
        """
        task = asyncio.create_task(coro)
        self._tasks.add(task)
        task.add_done_callback(self._on_task_done)
        return task

    def _on_task_done(self, task: asyncio.Task) -> None:
        """Drop the reference to a finished task and log it if it crashed."""
        self._tasks.discard(task)
        if task.cancelled():
            return
        error = task.exception()
        if error is not None:
            log.error("Background task %r failed", task, exc_info=error)

    # ── Static config loading ─────────────────────────────────────────────────

    def _static_channels(self):
        """Yield ``(network name, channel name, channel config)`` triples.

        Walks every ``channel`` entry of every ``network`` entry in the
        current configuration.
        """
        for net_cfg in self._cfg.get("network", []):
            net = net_cfg["name"]
            for ch_cfg in net_cfg.get("channel", []):
                yield net, ch_cfg["name"], ch_cfg

    def _load_static_webhooks(self):
        """Register every ``webhook`` entry from the config in the database."""
        for net, ch, ch_cfg in self._static_channels():
            for hook in ch_cfg.get("webhook", []):
                db.webhook_add(
                    self._database, net, ch,
                    hook["repo"],
                    hook.get("forge"),
                    # sorted() makes the default list deterministic; iterating
                    # the set directly depends on string hash ordering.
                    hook.get("events", sorted(DEFAULT_EVENTS)),
                    hook.get("branches", []),
                )

    def _load_static_rss(self):
        """Register every ``rss`` URL from the config in the database."""
        for net, ch, ch_cfg in self._static_channels():
            for url in ch_cfg.get("rss", []):
                db.rss_add(self._database, net, ch, url)  # (id, created) ignored here

    @staticmethod
    def _with_global_bind(net_cfg: dict, global_bind) -> dict:
        """Return *net_cfg* with the global bind address attached if needed.

        When a global bind is set and the network has no ``bind`` key of its
        own, a copy carrying ``_global_bind`` is returned; otherwise the
        original dict is returned unchanged.
        """
        if global_bind and "bind" not in net_cfg:
            return {**net_cfg, "_global_bind": global_bind}
        return net_cfg

    # ── IRC message routing ───────────────────────────────────────────────────

    async def _on_message(self, network: str, target: str, nick: str,
                          prefix: str, text: str):
        """Route an incoming IRC message to the command handlers.

        Messages whose target is the bot's own nick (case-insensitive) are
        handled as private messages and answered to the sender; messages whose
        target starts with ``#`` are handled as channel messages and answered
        in that channel. Anything else is ignored.
        """
        own_nick = self._clients[network].config["nickname"]
        if target.lower() == own_nick.lower():
            # Private message to the bot
            async def pm_reply(msg):
                self._clients[network].privmsg(nick, msg)

            await commands.handle_pm(
                network, nick, prefix, text,
                self._database, pm_reply, self.reload,
                self.join_channel, self.part_channel)
        elif target.startswith("#"):
            # Channel message
            async def ch_reply(msg):
                await self._deliver_irc(network, target, msg)

            await commands.handle_channel(
                network, target, nick, prefix, text,
                self._database, ch_reply, self.reload,
                self.join_channel, self.part_channel)

    async def _on_connected(self, network: str):
        """Log that *network* has finished connecting."""
        log.info("[%s] Connected and registered", network)

    # ── Hot reload ────────────────────────────────────────────────────────────

    async def join_channel(self, network: str, channel: str) -> str:
        """Join a channel on a network. Called from IRC commands.

        Returns a human-readable status line; unknown networks and channels
        the bot is already in are reported rather than raised.
        """
        client = self._clients.get(network)
        if not client:
            nets = ", ".join(self._clients) or "none"
            return f"Unknown network '{network}'. Connected networks: {nets}"
        if client.in_channel(channel):
            return f"Already in {channel} on {network}."
        client.join(channel)
        log.info("Joining %s on %s (via command)", channel, network)
        return f"Joining {channel} on {network}."

    async def part_channel(self, network: str, channel: str) -> str:
        """Part a channel on a network. Called from IRC commands.

        Besides sending PART, the channel's rows are purged from the database.
        Returns a human-readable status line.
        """
        client = self._clients.get(network)
        if not client:
            nets = ", ".join(self._clients) or "none"
            return f"Unknown network '{network}'. Connected networks: {nets}"
        if not client.in_channel(channel):
            return f"Not in {channel} on {network}."
        client.send(f"PART {channel} :Leaving")
        db.purge_channel(self._database, network, channel)
        log.info("Parting %s on %s (via command)", channel, network)
        return f"Left {channel} on {network}."

    async def reload(self) -> str:
        """
        Re-read the config file and reconcile network connections:
        - New networks → connect
        - Gone networks → QUIT and stop
        - Kept networks → join any new channels, part removed ones
        Returns a human-readable summary.
        """
        try:
            new_cfg = load_config(self._config_path)
        except Exception as e:
            return f"Failed to reload config: {e}"

        self._cfg = new_cfg
        self._shlink = shlink_module.from_config(new_cfg.get("shlink", {}))
        self._load_static_webhooks()
        self._load_static_rss()

        global_bind = new_cfg.get("bind")
        new_nets = {n["name"]: n for n in new_cfg.get("network", [])}
        # Snapshot, so entries can be deleted from _clients while iterating.
        current_nets = set(self._clients)
        added = []
        removed = []
        updated = []

        # Disconnect networks that are no longer in config
        for name in current_nets:
            if name in new_nets:
                continue
            log.info("Reload: disconnecting %s", name)
            await self._clients[name].stop("Configuration removed")
            del self._clients[name]
            db.purge_network(self._database, name)
            removed.append(name)

        # Connect new networks; reconcile channels on existing ones
        for name, net_cfg in new_nets.items():
            net_cfg = self._with_global_bind(net_cfg, global_bind)
            if name not in self._clients:
                # Brand new network
                log.info("Reload: connecting new network %s", name)
                self._start_network(net_cfg)
                added.append(name)
                continue

            # Existing network — reconcile channels
            client = self._clients[name]
            # NOTE(review): this reads the "channels" key, whereas the static
            # webhook/RSS loaders read "channel" tables — confirm both keys
            # are intended.
            new_chans = {c.lower() for c in net_cfg.get("channels", [])}
            cur_chans = set(client._channels)  # already lowercase
            for ch in new_chans - cur_chans:
                log.info("Reload: joining %s on %s", ch, name)
                client.join(ch)
            for ch in cur_chans - new_chans:
                log.info("Reload: parting %s on %s", ch, name)
                client.send(f"PART {ch} :Removed from config")
                db.purge_channel(self._database, name, ch)
            if new_chans != cur_chans:
                updated.append(name)

        parts = []
        if added:
            parts.append(f"connected: {', '.join(added)}")
        if removed:
            parts.append(f"disconnected: {', '.join(removed)}")
        if updated:
            parts.append(f"channels updated: {', '.join(updated)}")
        return "Reloaded. " + ("; ".join(parts) if parts else "no network changes.")

    def _start_network(self, net_cfg: dict):
        """Create the IRC client for *net_cfg* and start it in the background.

        Must be called while an event loop is running.
        """
        client = IRCClient(
            net_cfg,
            on_message=self._on_message,
            on_connected=self._on_connected,
        )
        self._clients[net_cfg["name"]] = client
        self._spawn(client.run())

    # ── IRC delivery ──────────────────────────────────────────────────────────

    async def _deliver_irc(self, network: str, channel: str, message: str):
        """Send *message* to *channel* on *network*.

        Unknown networks are logged and skipped. If the client is not in the
        channel yet, it joins first and waits two seconds before sending.
        """
        client = self._clients.get(network)
        if not client:
            log.warning("No client for network %s", network)
            return
        if not client.in_channel(channel):
            log.debug("[%s] Not in %s yet, joining…", network, channel)
            client.join(channel)
            await asyncio.sleep(2)
        client.privmsg(channel, message)

    # ── Webhook delivery ──────────────────────────────────────────────────────

    async def _on_webhook(self, forge: str, headers: dict, data: dict):
        """Turn one forge webhook event into IRC lines for subscribed channels.

        Events for unknown forges, for repositories without targets, or that
        the parser renders to nothing are dropped. A target is skipped when it
        restricts branches and the event's branch is not among them, or when
        none of the event's names fall in the categories the target allows.
        """
        parser = PARSERS.get(forge)
        if not parser:
            return

        full_name, repo_user, repo_name, organisation = parser.names(data, headers)
        branch = parser.branch(data, headers)
        events = parser.event(data, headers)
        primary = events[0] if events else ""

        targets = db.webhook_targets(
            self._database, forge, full_name, repo_user, organisation)
        if not targets:
            log.debug("[%s] No targets for %s", forge, full_name)
            return

        outputs = parser.parse(full_name, primary, data, headers,
                               commit_limit=self._commit_limit)
        if not outputs:
            return

        forge_tag = fmt.bold(f"[{forge.capitalize()}]")
        source = fmt.color(full_name or organisation or repo_name or forge,
                           fmt.COLOR_REPO)
        # Loop-invariant; "or ()" keeps an empty/None event list from raising.
        event_names = set(events or ())
        # Long URL -> short URL, so each URL is shortened at most once per
        # event no matter how many channels receive it.
        short_urls: dict = {}

        for target in targets:
            if branch and target["branches"] and branch not in target["branches"]:
                continue
            allowed = set(itertools.chain.from_iterable(
                parser.event_categories(e) for e in target["events"]
            ))
            if not event_names & allowed:
                continue
            for message, url in outputs:
                if url and self._shlink:
                    if url not in short_urls:
                        short_urls[url] = await self._shlink.shorten(url)
                    url = short_urls[url]
                line = f"{forge_tag} ({source}) {message}"
                if url:
                    line = f"{line} - {url}"
                await self._deliver_irc(target["network"], target["channel"], line)

    # ── Main run ──────────────────────────────────────────────────────────────

    async def run(self):
        """Start all networks plus the optional webhook server and RSS poller.

        Never returns on its own: after starting the background tasks it waits
        on an event that is never set.
        """
        global_bind = self._cfg.get("bind")
        for net_cfg in self._cfg.get("network", []):
            self._start_network(self._with_global_bind(net_cfg, global_bind))

        wh_cfg = self._cfg.get("webhook_server", {})
        if wh_cfg.get("enabled", True):
            # Per-forge secrets; fall back to legacy 'secret' key for all forges
            legacy = wh_cfg.get("secret", "")
            secrets = {
                "github": wh_cfg.get("github_secret", legacy),
                "gitea": wh_cfg.get("gitea_secret", legacy),
                "gitlab": wh_cfg.get("gitlab_secret", legacy),
            }
            server = WebhookServer(
                host=wh_cfg.get("host", "127.0.0.1"),
                port=wh_cfg.get("port", 8080),
                deliver=self._on_webhook,
                secrets=secrets,
            )
            self._spawn(server.run())

        rss_cfg = self._cfg.get("rss", {})
        if rss_cfg.get("enabled", True):
            poller = rss_module.RSSPoller(
                database=self._database,
                deliver=self._deliver_irc,
                interval=rss_cfg.get("interval", 300),
            )
            self._spawn(poller.run())

        log.info("gitbot started")
        await asyncio.Event().wait()  # run forever
def main():
    """Command-line entry point: parse arguments, load the config, run the bot.

    Exits with status 1 when the config file is missing or cannot be loaded,
    or when no owner account exists and ``--setup`` was not given. With
    ``--setup`` the interactive owner setup runs before the bot starts.
    Ctrl-C during the run is logged as a shutdown instead of raising.
    """
    import argparse

    cli = argparse.ArgumentParser(
        description="gitbot — git webhook + RSS IRC bot")
    cli.add_argument(
        "-c", "--config", default="gitbot.toml", metavar="FILE",
        help="Path to TOML config file (default: gitbot.toml)")
    cli.add_argument(
        "--setup", action="store_true",
        help="Create or reset the owner account, then start the bot")
    cli.add_argument(
        "-v", "--verbose", action="store_true",
        help="Enable debug logging")
    args = cli.parse_args()

    log_level = logging.DEBUG if args.verbose else logging.INFO
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s %(levelname)-7s %(name)s: %(message)s",
        datefmt="%H:%M:%S",
    )

    config_path = args.config
    try:
        config = load_config(config_path)
    except FileNotFoundError:
        print(f"Config file not found: {config_path}", file=sys.stderr)
        sys.exit(1)
    except Exception as e:
        print(f"Config error: {e}", file=sys.stderr)
        sys.exit(1)

    # Open DB early so setup can use it
    database = db.connect(config.get("database", "gitbot.db"))

    if args.setup:
        run_setup(database)
    else:
        if not auth.has_owner(database):
            print("No owner account found. Run with --setup first:", file=sys.stderr)
            print(f" python bot.py --setup -c {config_path}", file=sys.stderr)
            sys.exit(1)

    bot = Bot(config, config_path)
    # Re-use the already-open DB connection
    bot._database = database

    try:
        asyncio.run(bot.run())
    except KeyboardInterrupt:
        log.info("Shutting down")


# Only start the bot when this file is executed as a script.
if __name__ == "__main__":
    main()