1
0
mirror of https://github.com/TehPeGaSuS/GitBot.git synced 2026-06-25 09:05:46 +02:00
Files
GitBot/irc_client.py
TehPeGaSuS deb736254d Enhance IRC client with color stripping and mode handling
Added automatic color stripping for channels that restrict colors and improved mode handling.
2026-05-05 15:26:37 +02:00

216 lines
8.0 KiB
Python

"""Async IRC client with automatic mode handling and color stripping."""
import asyncio
import logging
import ssl
import re
import base64
from typing import Callable, Optional
# Logger shared by every IRCClient instance in this module.
log = logging.getLogger(name="irc")

# Bounds, in seconds, of the exponential back-off between reconnect attempts.
RECONNECT_DELAY_MIN, RECONNECT_DELAY_MAX = 5, 300
def strip_formatting(text: str) -> str:
"""Removes all IRC color and style control codes (Bold, Color, Underline, etc)."""
# Regex handles \x02 (bold), \x03 (color), \x0f (reset), \x1f (underline), etc.
return re.sub(r'\x02|\x1f|\x16|\x1d|\x0f|\x03(?:\d{1,2}(?:,\d{1,2})?)?', '', text)
class IRCClient:
    """One connection to one IRC network, with automatic reconnection.

    Formatting codes are stripped from outgoing PRIVMSGs to channels whose
    modes (``+c`` or ``+S``) restrict colours. The client learns those modes
    from MODE changes and from the RPL_CHANNELMODEIS (324) reply to the
    ``MODE #channel`` query it sends after each of its own joins.
    """

    # Channel modes treated as "colours not allowed".
    _COLOR_MODES = frozenset("cS")
    # Numerics that end SASL authentication (success or failure). After any of
    # them, capability negotiation must be closed with CAP END, otherwise the
    # server never completes registration.
    _SASL_DONE_NUMERICS = frozenset({"902", "903", "904", "905", "906", "907"})
    # The server prepends ":nick!user@host " when relaying our messages, and we
    # cannot know our host. Assume the common 63-byte HOSTLEN when reserving room.
    _MAX_HOST_LEN = 63
    # Maximum base64 payload per AUTHENTICATE line (IRCv3 SASL).
    _SASL_CHUNK = 400

    def __init__(self, config: dict, on_message: Callable,
                 on_connected: Callable):
        self.config = config
        self.name = config["name"]
        # async fn(network, target, nick, full_prefix, text)
        self.on_message = on_message
        # async fn(network)
        self.on_connected = on_connected
        self._writer: Optional[asyncio.StreamWriter] = None
        self._nick: str = config.get("nickname", "")  # nick currently in use
        self._channels: set = set()             # lower-cased channels we are in
        self._colorless_channels: set = set()   # lower-cased channels with +c/+S
        self._color_modes: dict = {}            # channel -> active modes from _COLOR_MODES
        self._cap_pending = False               # CAP negotiation still open
        self._reconnect_delay = RECONNECT_DELAY_MIN
        self._running = True
        self._ready = False

    # ── Public API ────────────────────────────────────────────────────────────

    async def run(self):
        """Connect and keep reconnecting, with exponential back-off, until stop()."""
        while self._running:
            try:
                await self._connect()
            except Exception as e:
                log.warning("[%s] Connection error: %s", self.name, e)
            if not self._running:
                break
            log.info("[%s] Reconnecting in %ds…", self.name, self._reconnect_delay)
            await asyncio.sleep(self._reconnect_delay)
            self._reconnect_delay = min(
                self._reconnect_delay * 2, RECONNECT_DELAY_MAX)

    async def stop(self, message: str = "Disconnecting"):
        """Stop reconnecting, send QUIT (best effort) and close the socket."""
        self._running = False
        # Keep a local reference: _connect() resets self._writer to None when
        # the connection ends, which can happen while we are awaiting below.
        writer = self._writer
        if writer is None:
            return
        try:
            self.send(f"QUIT :{message}")
            await writer.drain()
            await asyncio.sleep(0.5)  # give the server a moment to process QUIT
        except Exception as e:
            # Best effort: the connection may already be gone.
            log.debug("[%s] Error while disconnecting: %s", self.name, e)
        finally:
            writer.close()

    def send(self, line: str):
        """Write one raw IRC line. Does nothing while disconnected.

        Embedded CR/LF characters are replaced by spaces. A CR or LF would
        otherwise end the line early, and the server would run the remainder
        as a separate command.
        """
        if self._writer is None:
            return
        line = line.rstrip("\r\n").replace("\r", " ").replace("\n", " ")
        log.debug("[%s] >> %s", self.name, line)
        self._writer.write((line + "\r\n").encode("utf-8", errors="replace"))

    def privmsg(self, target: str, text: str):
        """Send *text* to *target*, splitting it to fit IRC's line limits.

        The text is stripped of formatting first if *target* is known to
        forbid colours. Each line of *text* becomes its own PRIVMSG, and
        empty lines are skipped because servers reject empty PRIVMSGs. Long
        lines are split on UTF-8 character boundaries.
        """
        if target.lower() in self._colorless_channels:
            text = strip_formatting(text)
        prefix = f"PRIVMSG {target} :"
        limit = self._payload_limit(prefix)
        for line in text.replace("\r\n", "\n").replace("\r", "\n").split("\n"):
            if not line:
                continue
            for chunk in self._split_utf8(line, limit):
                self.send(prefix + chunk)

    def join(self, channel: str):
        """Ask the server to join *channel*."""
        self.send(f"JOIN {channel}")

    def in_channel(self, channel: str) -> bool:
        """Return True if we have joined *channel* and not yet left it."""
        return channel.lower() in self._channels

    # ── Internal ──────────────────────────────────────────────────────────────

    async def _connect(self):
        """Open one connection, register, and process lines until it drops."""
        host = self.config["host"]
        port = self.config["port"]
        use_tls = self.config.get("tls", False)
        log.info("[%s] Connecting to %s:%d (tls=%s)…", self.name, host, port, use_tls)
        bind = self.config.get("bind") or self.config.get("_global_bind")
        local_addr = (bind, 0) if bind else None
        ctx = None
        if use_tls:
            ctx = ssl.create_default_context()
            if not self.config.get("tls_verify", True):
                ctx.check_hostname = False
                ctx.verify_mode = ssl.CERT_NONE
        reader, writer = await asyncio.open_connection(
            host, port, ssl=ctx, local_addr=local_addr)

        # Per-connection state starts fresh on every (re)connect.
        self._writer = writer
        self._channels = set()
        self._colorless_channels = set()
        self._color_modes = {}
        self._ready = False
        nick = self.config["nickname"]
        self._nick = nick
        username = self.config.get("username", nick)
        realname = self.config.get("realname", nick)
        self._cap_pending = "sasl_plain" in self.config
        if self._cap_pending:
            self.send("CAP REQ :sasl")
        self.send(f"NICK {nick}")
        self.send(f"USER {username} 0 * :{realname}")
        try:
            async for raw in self._read_lines(reader):
                try:
                    await self._handle(raw)
                except Exception:
                    # A bug in one handler or callback must not drop the connection.
                    log.exception("[%s] Error handling line: %s", self.name, raw)
        finally:
            writer.close()
            self._writer = None
            self._ready = False

    async def _read_lines(self, reader: asyncio.StreamReader):
        """Yield decoded lines until EOF or 300 s without any data."""
        buf = b""
        while True:
            try:
                data = await asyncio.wait_for(reader.read(4096), timeout=300)
            except asyncio.TimeoutError:
                log.warning("[%s] Read timeout, disconnecting", self.name)
                return
            if not data:
                log.info("[%s] Connection closed by server", self.name)
                return
            buf += data
            while b"\n" in buf:
                line, buf = buf.split(b"\n", 1)
                yield line.rstrip(b"\r").decode("utf-8", errors="replace")

    async def _handle(self, raw: str):
        """Parse one server line and react to the commands this client tracks."""
        log.debug("[%s] << %s", self.name, raw)
        prefix, command, params = self._parse(raw)
        nick = prefix.split("!", 1)[0]

        if command == "PING":
            self.send(f"PONG :{params[-1] if params else ''}")
        elif command == "CAP":
            self._on_cap(params)
        elif command == "AUTHENTICATE":
            self._on_authenticate(params)
        elif command in self._SASL_DONE_NUMERICS:
            self._on_sasl_done(command, params)
        elif command == "001":
            await self._on_welcome(params)
        elif command == "433":
            self._on_nick_in_use()
        elif command == "NICK":
            if params and self._is_me(nick):
                self._nick = params[0]
        elif command == "MODE":
            # Only channel modes matter; user-mode changes target our nick.
            if len(params) >= 2 and params[0].startswith("#"):
                self._apply_modes(params[0], params[1])
        elif command == "324":
            # RPL_CHANNELMODEIS: <me> <channel> <modes> [mode params...]
            if len(params) >= 3:
                self._apply_modes(params[1], params[2], reset=True)
        elif command == "JOIN":
            if params and self._is_me(nick):
                self._channels.add(params[0].lower())
                self.send(f"MODE {params[0]}")  # the reply (324) carries the modes
        elif command == "PART":
            if params and self._is_me(nick):
                for channel in params[0].split(","):
                    self._forget_channel(channel)
        elif command == "KICK":
            if len(params) >= 2 and self._is_me(params[1]):
                self._forget_channel(params[0])
        elif command == "PRIVMSG":
            if len(params) >= 2:
                await self.on_message(self.name, params[0], nick, prefix, params[1])

    # ── Registration / SASL ───────────────────────────────────────────────────

    def _on_cap(self, params: list):
        """Handle CAP replies: start SASL on ACK, otherwise close negotiation."""
        if len(params) < 2:
            return
        sub = params[1].upper()
        if sub == "ACK" and "sasl" in params[-1].split():
            self.send("AUTHENTICATE PLAIN")
        elif sub in ("ACK", "NAK"):
            if sub == "NAK":
                log.warning("[%s] Server refused SASL capability", self.name)
            self._end_cap()

    def _on_authenticate(self, params: list):
        """Answer the server's empty PLAIN challenge with the credentials."""
        sasl = self.config.get("sasl_plain")
        if not sasl or not params or params[0] != "+":
            return
        payload = f"\x00{sasl['user']}\x00{sasl['password']}".encode("utf-8")
        token = base64.b64encode(payload).decode("ascii")
        step = self._SASL_CHUNK
        for i in range(0, len(token), step):
            self.send(f"AUTHENTICATE {token[i:i + step]}")
        # A payload that is an exact multiple of the chunk size must be
        # terminated by an explicit empty response.
        if len(token) % step == 0:
            self.send("AUTHENTICATE +")

    def _on_sasl_done(self, numeric: str, params: list):
        """Log the SASL outcome and let registration proceed."""
        if numeric == "903":
            log.info("[%s] SASL authentication succeeded", self.name)
        else:
            log.warning("[%s] SASL authentication failed (%s): %s",
                        self.name, numeric, params[-1] if params else "")
        self._end_cap()

    def _end_cap(self):
        """Send CAP END once per connection, if negotiation is still open."""
        if self._cap_pending:
            self._cap_pending = False
            self.send("CAP END")

    def _on_nick_in_use(self):
        """ERR_NICKNAMEINUSE during registration: retry with a '_' suffix."""
        if self._ready:
            return
        self._nick += "_"
        log.warning("[%s] Nickname in use, trying %s", self.name, self._nick)
        self.send(f"NICK {self._nick}")

    async def _on_welcome(self, params: list):
        """RPL_WELCOME (001): identify, join configured channels, notify."""
        self._reconnect_delay = RECONNECT_DELAY_MIN
        self._cap_pending = False
        if params:
            self._nick = params[0]  # the nick the server actually registered
        if "nickserv_password" in self.config:
            ns_pw = self.config["nickserv_password"]
            self.send(f"PRIVMSG NickServ :IDENTIFY {ns_pw}")
        for ch in self.config.get("channels", []):
            self.join(ch)
        self._ready = True
        await self.on_connected(self.name)

    # ── State helpers ─────────────────────────────────────────────────────────

    def _is_me(self, nick: str) -> bool:
        """True if *nick* is the nick this connection currently uses."""
        return nick.lower() == self._nick.lower()

    def _forget_channel(self, channel: str):
        """Drop all state for *channel* after we left or were kicked."""
        key = channel.lower()
        self._channels.discard(key)
        self._colorless_channels.discard(key)
        self._color_modes.pop(key, None)

    def _apply_modes(self, channel: str, modes: str, reset: bool = False):
        """Update the colour restriction of *channel* from a mode string.

        *modes* looks like ``"+nt-c+S"``; only modes in _COLOR_MODES matter.
        Characters before any sign are treated as additions. With
        ``reset=True`` (used for 324), *modes* is taken as the channel's
        complete mode set instead of a delta.
        """
        key = channel.lower()
        active = set() if reset else set(self._color_modes.get(key, ()))
        adding = True
        for ch in modes:
            if ch == "+":
                adding = True
            elif ch == "-":
                adding = False
            elif ch in self._COLOR_MODES:
                if adding:
                    active.add(ch)
                else:
                    active.discard(ch)
        if active:
            self._color_modes[key] = active
            self._colorless_channels.add(key)
        else:
            self._color_modes.pop(key, None)
            self._colorless_channels.discard(key)

    def _payload_limit(self, prefix: str) -> int:
        """Bytes of message text that fit after *prefix* once relayed.

        The relayed form ``:nick!~user@host <prefix><text>\\r\\n`` must fit
        in 512 bytes. Always returns at least 1.
        """
        user = self.config.get("username", self._nick)
        source = len(f":{self._nick}!~{user}@ ".encode("utf-8")) + self._MAX_HOST_LEN
        return max(510 - source - len(prefix.encode("utf-8")), 1)

    @staticmethod
    def _split_utf8(text: str, limit: int) -> list:
        """Split *text* into pieces of at most *limit* UTF-8 bytes.

        Never cuts inside a code point. A single code point longer than
        *limit* is emitted whole. Returns [] for empty text.
        """
        data = text.encode("utf-8", errors="replace")
        limit = max(limit, 1)
        pieces = []
        while data:
            cut = min(limit, len(data))
            # UTF-8 continuation bytes are 0b10xxxxxx; back up to a lead byte.
            while 0 < cut < len(data) and (data[cut] & 0xC0) == 0x80:
                cut -= 1
            if cut == 0:
                # The first code point alone exceeds the limit: take it whole.
                cut = 1
                while cut < len(data) and (data[cut] & 0xC0) == 0x80:
                    cut += 1
            pieces.append(data[:cut].decode("utf-8", errors="replace"))
            data = data[cut:]
        return pieces

    @staticmethod
    def _parse(raw: str) -> tuple:
        """Split a raw IRC line into ``(prefix, command, params)``.

        The prefix is returned without its leading ':' ("" if absent). The
        command is upper-cased. The trailing parameter (after " :") is kept
        verbatim, including leading colons and repeated spaces. IRCv3 message
        tags, if present, are dropped.
        """
        line = raw
        if line.startswith("@"):
            line = line.partition(" ")[2]
        prefix = ""
        if line.startswith(":"):
            prefix, _, line = line[1:].partition(" ")
        head, sep, trailing = line.partition(" :")
        params = head.split()
        if sep:
            params.append(trailing)
        command = params.pop(0).upper() if params else ""
        return prefix, command, params