mirror of
https://github.com/TehPeGaSuS/GitBot.git
synced 2026-06-25 09:05:46 +02:00
deb736254d
Added automatic color stripping for channels that restrict colors and improved mode handling.
216 lines
8.0 KiB
Python
216 lines
8.0 KiB
Python
"""Async IRC client with automatic mode handling and color stripping."""
|
|
|
|
import asyncio
|
|
import logging
|
|
import ssl
|
|
import re
|
|
import base64
|
|
from typing import Callable, Optional
|
|
|
|
log = logging.getLogger(name="irc")

# Bounds, in seconds, of the exponential reconnect back-off used by
# IRCClient.run(): the wait starts at the minimum, doubles after every
# attempt and is capped at the maximum.
RECONNECT_DELAY_MIN, RECONNECT_DELAY_MAX = 5, 300
|
|
|
|
# Every IRC formatting control code:
#   \x02 bold, \x1d italic, \x1f underline, \x1e strikethrough,
#   \x11 monospace, \x16 reverse, \x0f reset,
#   \x03[fg[,bg]]          colour, 1-2 decimal digits per number,
#   \x04[RRGGBB[,RRGGBB]]  hex colour.
# Compiled once at import time because it runs on every outgoing message
# to a colour-restricted channel.
_FORMATTING_RE = re.compile(
    r"[\x02\x0f\x11\x16\x1d\x1e\x1f]"
    r"|\x03(?:\d{1,2}(?:,\d{1,2})?)?"
    r"|\x04(?:[0-9a-fA-F]{6}(?:,[0-9a-fA-F]{6})?)?"
)


def strip_formatting(text: str) -> str:
    """Return *text* with all IRC colour and style control codes removed.

    Handles the single-byte style toggles (bold, italic, underline,
    strikethrough, monospace, reverse, reset) as well as both colour forms:
    ``\\x03`` followed by optional decimal fg[,bg] numbers and ``\\x04``
    followed by optional hex RRGGBB[,RRGGBB] values. All other text,
    including non-ASCII characters, is returned unchanged.
    """
    return _FORMATTING_RE.sub("", text)
|
|
|
|
class IRCClient:
    """One IRC network connection with automatic reconnect.

    ``run()`` keeps (re)connecting with exponential back-off until
    ``stop()`` is called. After the server's welcome (001) the client
    optionally identifies to NickServ, joins ``config["channels"]`` and
    awaits ``on_connected(network)``. Every PRIVMSG received is forwarded to
    ``on_message(network, target, nick, full_prefix, text)``.

    Config keys read: ``name``, ``host``, ``port``, ``nickname`` (required);
    ``tls``, ``tls_verify``, ``bind``, ``_global_bind``, ``username``,
    ``realname``, ``sasl_plain`` (dict with ``user`` / ``password``),
    ``nickserv_password``, ``channels`` (optional).
    """

    # Channel modes treated as "no colours": formatting is stripped before
    # messaging a channel that has any of them set.
    # NOTE(review): +S means "strip colours" on some ircds but something else
    # (e.g. TLS-only) on others — confirm against the networks in use.
    _COLOR_MODES = frozenset("cS")

    # Leading characters that mark a channel name (RFC 2811 section 2.1).
    _CHANNEL_PREFIXES = ("#", "&", "+", "!")

    # Bytes kept free in each outgoing PRIVMSG for the ":nick!user@host "
    # source prefix the server prepends when relaying it to other clients.
    # The relayed line must still fit in 512 bytes, otherwise its tail is
    # cut off. Conservative estimate, not a protocol constant.
    _RELAY_PREFIX_RESERVE = 100

    def __init__(self, config: dict, on_message: Callable,
                 on_connected: Callable):
        self.config = config
        self.name = config["name"]
        self.on_message = on_message  # async fn(network, target, nick, full_prefix, text)
        self.on_connected = on_connected  # async fn(network)

        self._writer: Optional[asyncio.StreamWriter] = None
        # Nick the server currently knows us by; can differ from the
        # configured one after a 433 retry or a NICK change.
        self._nick: str = config.get("nickname", "")
        self._channels: set = set()  # lower-cased names of joined channels
        self._colorless_channels: set = set()  # Channels with +c or +S active
        self._color_modes: dict = {}  # lower-cased channel -> active _COLOR_MODES letters
        self._cap_pending = False  # True while registration waits for our CAP END
        self._reconnect_delay = RECONNECT_DELAY_MIN
        self._running = True
        self._ready = False

    # ── Public API ────────────────────────────────────────────────────────────

    async def run(self):
        """Connect, and reconnect with exponential back-off, until stop()."""
        while self._running:
            try:
                await self._connect()
            except Exception as e:
                log.warning("[%s] Connection error: %s", self.name, e)
            if not self._running:
                break
            log.info("[%s] Reconnecting in %ds…", self.name, self._reconnect_delay)
            await asyncio.sleep(self._reconnect_delay)
            self._reconnect_delay = min(
                self._reconnect_delay * 2, RECONNECT_DELAY_MAX)

    async def stop(self, message: str = "Disconnecting"):
        """Stop reconnecting, send QUIT with *message* and close the socket."""
        self._running = False
        # Keep a local reference: once the server acts on QUIT and drops the
        # link, _connect()'s cleanup sets self._writer to None mid-sleep.
        writer = self._writer
        if writer is None:
            return
        try:
            self.send(f"QUIT :{message}")
            await asyncio.sleep(0.5)  # give QUIT a moment to reach the server
            writer.close()
        except Exception as exc:
            # Best effort: the connection may already be gone.
            log.debug("[%s] Error while disconnecting: %s", self.name, exc)

    def send(self, line: str):
        """Write one raw IRC line (CR-LF appended); no-op while disconnected.

        CR and LF inside *line* are replaced by spaces. Left in place they
        would end the line early and the server would execute the remainder
        as a separate command (an injection vector for relayed user text).
        """
        writer = self._writer
        if writer is None:
            return
        line = line.replace("\r", " ").replace("\n", " ")
        log.debug("[%s] >> %s", self.name, line)
        writer.write((line + "\r\n").encode("utf-8", errors="replace"))

    def privmsg(self, target: str, text: str):
        """Send *text* to *target* as one PRIVMSG per line and size-limited chunk.

        Formatting codes are stripped first when *target* is a channel known
        to forbid colours. Multi-line text becomes several messages (empty
        lines are skipped, as IRC rejects empty PRIVMSGs). Long lines are
        split on UTF-8 character boundaries, leaving headroom for the
        server's relay prefix.
        """
        if target.lower() in self._colorless_channels:
            text = strip_formatting(text)

        prefix = f"PRIVMSG {target} :"
        limit = 510 - len(prefix.encode()) - self._RELAY_PREFIX_RESERVE
        # Split only on real line breaks: str.splitlines() would also break
        # on \x1d / \x1e, which are the IRC italic / strikethrough codes.
        for line in text.replace("\r\n", "\n").replace("\r", "\n").split("\n"):
            for chunk in self._split_utf8(line, limit):
                self.send(prefix + chunk)

    def join(self, channel: str):
        """Ask the server to join *channel*."""
        self.send(f"JOIN {channel}")

    def in_channel(self, channel: str) -> bool:
        """Return True if our own JOIN for *channel* has been seen (and no PART/KICK since)."""
        return channel.lower() in self._channels

    # ── Internal ──────────────────────────────────────────────────────────────

    async def _connect(self):
        """Open one connection, register, and process lines until it drops."""
        host = self.config["host"]
        port = self.config["port"]
        use_tls = self.config.get("tls", False)

        log.info("[%s] Connecting to %s:%d (tls=%s)…", self.name, host, port, use_tls)

        bind = self.config.get("bind") or self.config.get("_global_bind")
        local_addr = (bind, 0) if bind else None

        ssl_ctx = None
        if use_tls:
            ssl_ctx = ssl.create_default_context()
            if not self.config.get("tls_verify", True):
                ssl_ctx.check_hostname = False
                ssl_ctx.verify_mode = ssl.CERT_NONE
        reader, writer = await asyncio.open_connection(
            host, port, ssl=ssl_ctx, local_addr=local_addr)

        self._writer = writer
        self._channels = set()
        self._colorless_channels = set()
        self._color_modes = {}
        self._ready = False

        nick = self.config["nickname"]
        self._nick = nick
        username = self.config.get("username", nick)
        realname = self.config.get("realname", nick)

        # CAP REQ suspends registration until we send CAP END.
        self._cap_pending = "sasl_plain" in self.config
        if self._cap_pending:
            self.send("CAP REQ :sasl")

        self.send(f"NICK {nick}")
        self.send(f"USER {username} 0 * :{realname}")

        try:
            async for raw in self._read_lines(reader):
                await self._handle(raw)
        finally:
            writer.close()
            self._writer = None
            self._ready = False

    async def _read_lines(self, reader: asyncio.StreamReader):
        """Yield decoded lines until EOF or 300 s without any data."""
        buf = b""
        while True:
            try:
                data = await asyncio.wait_for(reader.read(4096), timeout=300)
            except asyncio.TimeoutError:
                log.warning("[%s] Read timeout, disconnecting", self.name)
                return
            if not data:
                log.info("[%s] Connection closed by server", self.name)
                return
            buf += data
            while b"\n" in buf:
                line, buf = buf.split(b"\n", 1)
                yield line.rstrip(b"\r").decode("utf-8", errors="replace")

    async def _handle(self, raw: str):
        """Dispatch one raw line received from the server."""
        log.debug("[%s] << %s", self.name, raw)

        prefix, command, params = self._parse_line(raw)
        if not command:
            return
        source_nick = prefix.split("!", 1)[0]

        if command == "PING":
            token = params[-1] if params else ""
            self.send(f"PONG :{token}")
            return

        # CAP negotiation for SASL
        if command == "CAP" and len(params) >= 2:
            sub = params[1].upper()
            caps = params[2].split() if len(params) >= 3 else []
            if sub == "ACK" and "sasl" in caps:
                self.send("AUTHENTICATE PLAIN")
            elif sub == "NAK":
                log.warning("[%s] Server refused SASL capability", self.name)
                self._end_cap()
            return

        if command == "AUTHENTICATE":
            sasl = self.config["sasl_plain"]
            token = base64.b64encode(
                f"\x00{sasl['user']}\x00{sasl['password']}".encode()
            ).decode()
            self.send(f"AUTHENTICATE {token}")
            return

        # 903 = SASL success; 902/904-907 = failure or abort. Either way
        # finish CAP negotiation so registration can continue.
        if command in ("902", "903", "904", "905", "906", "907"):
            if command != "903":
                log.warning("[%s] SASL authentication failed (%s)", self.name, command)
            self._end_cap()
            return

        # 433 = nickname in use. During registration the server waits for
        # another NICK, so retry with "_" appended instead of stalling.
        if command == "433" and not self._ready:
            self._nick += "_"
            self.send(f"NICK {self._nick}")
            return

        # 001 = welcome
        if command == "001":
            if params:
                self._nick = params[0]  # the nick the server registered us as
            self._cap_pending = False
            self._reconnect_delay = RECONNECT_DELAY_MIN
            if "nickserv_password" in self.config:
                ns_pw = self.config["nickserv_password"]
                self.send(f"PRIVMSG NickServ :IDENTIFY {ns_pw}")
            for ch in self.config.get("channels", []):
                self.join(ch)
            self._ready = True
            await self._notify(self.on_connected, self.name)
            return

        if command == "NICK" and params and self._is_me(source_nick):
            self._nick = params[0]
            return

        # MODE change (params: target, mode string, mode args...)
        if command == "MODE" and len(params) >= 2:
            if params[0].startswith(self._CHANNEL_PREFIXES):
                self._apply_color_modes(params[0], params[1], replace=False)
            return

        # 324 = RPL_CHANNELMODEIS (params: our nick, channel, modes, args...),
        # the full current mode set, so it replaces what we tracked.
        if command == "324" and len(params) >= 3:
            self._apply_color_modes(params[1], params[2], replace=True)
            return

        if command == "JOIN" and params:
            channel = params[0]
            if self._is_me(source_nick):
                self._channels.add(channel.lower())
                self.send(f"MODE {channel}")  # Request modes immediately
            return

        if command == "PART" and params and self._is_me(source_nick):
            self._forget_channel(params[0])
            return

        if command == "KICK" and len(params) >= 2 and self._is_me(params[1]):
            self._forget_channel(params[0])
            return

        if command == "PRIVMSG" and len(params) >= 2:
            target = params[0]
            # Join extra middle params too, to tolerate a malformed
            # multi-word message sent without the ":" trailing marker.
            text = " ".join(params[1:])
            await self._notify(self.on_message, self.name, target,
                               source_nick, prefix, text)
            return

    async def _notify(self, callback: Callable, *args):
        """Await ``callback(*args)``, logging instead of propagating its errors.

        A failing handler must not tear down the IRC connection (an
        exception escaping _handle would end _connect and force a reconnect).
        """
        try:
            await callback(*args)
        except Exception:
            log.exception("[%s] Event handler failed", self.name)

    def _end_cap(self):
        """Send CAP END once, if registration is still waiting for it."""
        if self._cap_pending:
            self._cap_pending = False
            self.send("CAP END")

    def _is_me(self, nick: str) -> bool:
        """Return True if *nick* is our current nick (case-insensitive)."""
        return nick.lower() == self._nick.lower()

    def _forget_channel(self, channel: str):
        """Drop all state kept for *channel* after we left or were kicked."""
        key = channel.lower()
        self._channels.discard(key)
        self._colorless_channels.discard(key)
        self._color_modes.pop(key, None)

    def _apply_color_modes(self, channel: str, modes: str, replace: bool):
        """Update colour-restriction tracking for *channel* from a mode string.

        With *replace* the mode string is the channel's complete mode set
        (324). Otherwise it is an incremental change (MODE) applied to the
        letters already tracked.
        """
        key = channel.lower()
        active = set() if replace else set(self._color_modes.get(key, ()))
        for adding, letter in self._parse_mode_changes(modes):
            if letter in self._COLOR_MODES:
                if adding:
                    active.add(letter)
                else:
                    active.discard(letter)
        self._color_modes[key] = active
        if active:
            self._colorless_channels.add(key)
        else:
            self._colorless_channels.discard(key)

    @staticmethod
    def _parse_line(raw: str) -> tuple:
        """Split a raw IRC line into ``(prefix, COMMAND, params)``.

        A leading IRCv3 ``@tags`` section is skipped, the ``:prefix`` is
        returned without its colon ("" if absent), the command is
        upper-cased, and a ``:trailing`` parameter becomes the last entry of
        *params* with only its single marker colon removed.
        """
        line = raw
        if line.startswith("@"):
            _, _, line = line.partition(" ")
        prefix = ""
        if line.startswith(":"):
            prefix, _, line = line[1:].partition(" ")
        trailing = None
        if " :" in line:
            line, trailing = line.split(" :", 1)
        params = line.split()
        command = params.pop(0).upper() if params else ""
        if trailing is not None:
            params.append(trailing)
        return prefix, command, params

    @staticmethod
    def _parse_mode_changes(modes: str) -> list:
        """Return ``(adding, letter)`` pairs for each letter in a mode string.

        "+nt-c" -> [(True, "n"), (True, "t"), (False, "c")]. Letters before
        any sign count as additions.
        """
        changes = []
        adding = True
        for ch in modes:
            if ch == "+":
                adding = True
            elif ch == "-":
                adding = False
            else:
                changes.append((adding, ch))
        return changes

    @staticmethod
    def _split_utf8(text: str, limit: int) -> list:
        """Split *text* into pieces whose UTF-8 encoding is at most *limit* bytes.

        Cuts are moved back to a character start so multi-byte characters
        are never split. Unencodable characters (lone surrogates) become "?".
        Returns [] for empty text.

        Raises:
            ValueError: if *limit* < 4 (a UTF-8 character can take 4 bytes,
                so smaller limits could not always make progress).
        """
        if limit < 4:
            raise ValueError(f"limit must be at least 4 bytes, got {limit}")
        data = text.encode("utf-8", errors="replace")
        pieces = []
        while len(data) > limit:
            cut = limit
            # Continuation bytes are 0b10xxxxxx; back up to a lead byte.
            while (data[cut] & 0xC0) == 0x80:
                cut -= 1
            pieces.append(data[:cut].decode("utf-8"))
            data = data[cut:]
        if data:
            pieces.append(data.decode("utf-8"))
        return pieces
|