You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

226 lines
8.4 KiB

import asyncio
import importlib
import json
import os
import signal
import sys
import time
import traceback
from pathlib import Path
from typing import Any, Iterable, Optional
import irc.client
import irc.client_aio
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from irc.client import Connection, Event, NickMask
4 years ago
from edmond.log import Logger
from edmond.plugin import Plugin
4 years ago
class Bot(irc.client_aio.AioSimpleIRCClient, Logger):
    """Main class for the IRC bot: handles connection and manages plugins.

    Attributes:
        config: configuration mapping given to the constructor; this class
            reads the keys "nick", "alternative_nicks", "storage_file",
            "channels", "host", "port" and "speak_delay" from it.
        logger: logger object given to the constructor.
        plugins: plugin instances created by `load_plugins`.
        values: runtime-only values (never written to disk); holds the
            joined channels and one dict per loaded plugin.
        storage: persistent values, loaded from and saved to the JSON file
            named by `config["storage_file"]`.
        tasks: asyncio tasks started through `handle_task` that have not
            finished yet.
        done: True once `cleanup` has run.
        scheduler: an `AsyncIOScheduler`, started by the constructor.
    """

    CHANNELS_RUNTIME_KEY = "_channels"

    def __init__(self, config: dict, logger):
        """Set up the bot state, load the storage and start the scheduler."""
        super().__init__()
        self.config: dict = config
        # Assigned before `get_storage()` runs, which already logs.
        # NOTE(review): presumably the Logger mixin reads `self.logger`.
        self.logger = logger
        self.plugins: list[Plugin] = []
        self.values: dict[str, Any] = {}
        self.storage: dict[str, Any] = self.get_storage()
        self.tasks: list[asyncio.Task] = []
        self.done: bool = False
        self.scheduler = AsyncIOScheduler()
        self.scheduler.start()

    @property
    def nick(self) -> str:
        """Nickname validated by the server, or the configured nick."""
        if self.connection.is_connected():
            return self.connection.get_nickname()
        return self.config["nick"]

    @property
    def names(self) -> Iterable[str]:
        """Collection of names the bot should identify with."""
        return (self.nick, *self.config["alternative_nicks"])

    @property
    def channels(self) -> list[str]:
        """List of joined channels (kept in the runtime values)."""
        return self.values.setdefault(self.CHANNELS_RUNTIME_KEY, [])

    def get_storage(self) -> dict:
        """Load data from storage.

        Returns the decoded JSON content of the storage file, or an empty
        dict (after logging an error and a warning) when the file cannot be
        read or decoded.
        """
        try:
            with open(
                self.config["storage_file"], "rt", encoding="utf-8"
            ) as storage_file:
                storage = json.load(storage_file)
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
        except (OSError, ValueError) as exc:
            self.log_e(f"Could not load storage file: {exc}")
            self.log_w(
                "If it's not the first time Edm0nd is run, you may lose"
                " data when closing the program."
            )
            return {}
        self.log_d("Loaded storage file.")
        return storage

    def save_storage(self) -> None:
        """Save storage data to disk.

        The data is serialized before the file is opened, so a value that
        cannot be encoded as JSON does not truncate the existing file.
        Failures are logged, never raised.
        """
        try:
            serialized = json.dumps(self.storage, indent=2, sort_keys=True)
        except (TypeError, ValueError) as exc:
            self.log_e(f"Could not serialize storage data: {exc}")
            return
        try:
            with open(
                self.config["storage_file"], "wt", encoding="utf-8"
            ) as storage_file:
                storage_file.write(serialized)
        except OSError as exc:
            self.log_e(f"Could not save storage file: {exc}")
            return
        self.log_d("Saved storage file.")

    def handle_task(self, coro):
        """Schedule a task in the event loop. Keep a reference to cancel it.

        The reference is dropped as soon as the task finishes, so the task
        list only holds tasks that may still need to be cancelled.
        """
        task = self.connection.reactor.loop.create_task(coro)
        self.tasks.append(task)
        task.add_done_callback(self._forget_task)

    def _forget_task(self, task: asyncio.Task) -> None:
        """Drop a finished task from the task list and log its failure."""
        if task in self.tasks:
            self.tasks.remove(task)
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            self.log_e(
                f"Task {task.get_name()} failed with"
                f" {type(exc).__name__}: {exc}"
            )

    def on_welcome(self, connection: Connection, event: Event):
        """Handle a successful connection to a server."""
        self.log_i(f"Connected to server {event.source}.")
        self.run_plugin_callbacks(event)
        for channel in self.config["channels"]:
            connection.join(channel)

    def on_join(self, connection: Connection, event: Event):
        """Handle someone, possibly the bot, joining a channel."""
        if NickMask(event.source).nick == self.nick:
            self.log_i(f"Joined {event.target}.")
            # Joining again (e.g. after a reconnection) must not create a
            # duplicate entry.
            if event.target not in self.channels:
                self.channels.append(event.target)
        self.run_plugin_callbacks(event)

    def on_part(self, connection: Connection, event: Event):
        """Handle someone, possibly the bot, leaving a channel."""
        if NickMask(event.source).nick == self.nick:
            # The part message is optional: the arguments may be empty.
            reason = event.arguments[0] if event.arguments else ""
            self.log_i(f"Left {event.target} (args: {reason}).")
            if event.target in self.channels:
                self.channels.remove(event.target)
        self.run_plugin_callbacks(event)

    def on_pubmsg(self, connection: Connection, event: Event):
        """Handle a message received in a channel."""
        channel = event.target
        nick = NickMask(event.source).nick
        message = event.arguments[0]
        self.log_d(f"Message in {channel} from {nick}: {message}")
        self.run_plugin_callbacks(event)

    def on_privmsg(self, connection: Connection, event: Event):
        """Handle a message received privately, usually like a channel msg."""
        nick = NickMask(event.source).nick
        target = event.target
        message = event.arguments[0]
        self.log_d(f"Private message from {nick} to {target}: {message}")
        self.run_plugin_callbacks(event)

    def on_ping(self, connection: Connection, event: Event):
        """Handle a ping; can be used as a random event timer."""
        self.log_d(f"Received ping from {event.target}.")
        self.run_plugin_callbacks(event)

    def run(self):
        """Connect the bot to server, join channels and start responding.

        Blocks until the client stops. Connection errors, keyboard
        interrupts and any other exception are logged rather than raised,
        and `cleanup` is always called on the way out.
        """
        self.log_i("Starting Edmond.")
        self.load_plugins()
        self.log_i("Connecting to server…")
        signal.signal(signal.SIGTERM, self.handle_sigterm)
        try:
            self.connect(self.config["host"], self.config["port"], self.nick)
            self.start()
        except irc.client.ServerConnectionError as exc:
            self.log_c(f"Connection failed: {exc}")
        except KeyboardInterrupt:
            self.log_i("Caught keyboard interrupt.")
        except Exception as exc:
            self.log_c(f"Caught unhandled {type(exc).__name__}: {exc}")
            for line in traceback.format_tb(exc.__traceback__):
                self.log_d(line.rstrip())
        finally:
            self.cleanup()

    def load_plugins(self):
        """Load all installed plugins.

        Every `.py` file of the `plugins` directory next to this module,
        except `__init__.py`, is imported as `edmond.plugins.<name>`. The
        plugin class is looked up by name (`some_name` -> `SomeNamePlugin`),
        instantiated with the bot and given an empty dict in `values`.
        Plugins are loaded in alphabetical order so that the resulting
        plugin order does not depend on the file system.
        """
        self.log_i("Loading plugins…")
        plugins_dir = Path(__file__).parent / "plugins"
        plugin_names = sorted(
            os.path.splitext(filename)[0]
            for filename in os.listdir(plugins_dir)
            if filename.endswith(".py") and filename != "__init__.py"
        )
        for plugin_name in plugin_names:
            module = importlib.import_module(f"edmond.plugins.{plugin_name}")
            # Get plugin class name from its module name.
            class_name = (
                "".join(word.capitalize() for word in plugin_name.split("_"))
                + "Plugin"
            )
            plugin_class = getattr(module, class_name)
            self.plugins.append(plugin_class(self))
            self.values[plugin_name] = {}
            self.log_d(f"Loaded {class_name}.")

    def get_plugin(self, name: str) -> Optional[Plugin]:
        """Get a loaded plugin by its name (e.g. 'mood'), or None."""
        return next(
            (plugin for plugin in self.plugins if plugin.name == name),
            None,
        )

    def say(self, target: str, message: str) -> None:
        """Send message to target after a slight delay.

        Line breaks are replaced with spaces. A message starting with
        "/me " is sent as an action. A message that is too long is logged
        and dropped.
        """
        message = message.replace("\n", " ").replace("\r", " ")
        # NOTE(review): time.sleep blocks the calling thread; if this is
        # called from the event loop thread, nothing else is processed
        # during the delay. Kept as is because callers get a synchronous
        # "sent on return" behaviour.
        time.sleep(self.config["speak_delay"])
        self.log_d(f"Sending to {target}: {message}")
        try:
            if message.startswith("/me "):
                self.connection.action(target, message[4:])
            else:
                self.connection.privmsg(target, message)
        except irc.client.MessageTooLong:
            self.log_e("Could not send, message is too long.")

    def run_plugin_callbacks(self, event: Event) -> None:
        """Run appropriate callbacks for each plugin.

        Ready plugins are visited by decreasing priority. A callback that
        returns a truthy value stops the remaining plugins from seeing the
        event.
        """
        etype = event.type
        ready_plugins = (plugin for plugin in self.plugins if plugin.is_ready)
        for plugin in sorted(
            ready_plugins, key=lambda p: p.priority, reverse=True
        ):
            callbacks = plugin.callbacks
            if etype not in callbacks:
                continue
            if callbacks[etype](event):
                break

    def handle_sigterm(self, *args):
        """Handle SIGTERM (keyboard interrupt, systemd stop, etc)."""
        self.cleanup()
        # sys.exit rather than the `exit` builtin, which is only installed
        # by the `site` module.
        sys.exit("Exiting after received SIGTERM.")

    def cleanup(self) -> None:
        """Save the storage file and close the connection. Run only once.

        Saving, task cancellation and disconnection happen on the first
        call only. The event loop is closed as soon as it is not running
        anymore, which can be on a later call: a running loop (e.g. when
        called from the SIGTERM handler) cannot be closed.
        """
        if not self.done:
            # Flag first, so that a failure below or a call from the
            # signal handler cannot make the whole cleanup run twice.
            self.done = True
            self.log_i("Stopping Edmond.")
            self.save_storage()  # FIRST THINGS FIRST
            for task in list(self.tasks):
                if not task.done():
                    self.log_d(f"Cancelling task {task.get_name()}")
                    task.cancel()
            if self.connection.is_connected():
                self.connection.close()
        loop = self.reactor.loop
        if not (loop.is_running() or loop.is_closed()):
            loop.close()