Edm0nd/edmond/plugins/kagi_fastgpt.py
2024-03-08 18:22:56 +01:00

82 lines
2.5 KiB
Python

import json
import random
import re
from typing import Optional, Tuple, cast
import requests
from edmond.plugin import Plugin
from edmond.plugins.plus import PlusPlugin
from edmond.utils import limit_text_length
class KagiFastgptPlugin(Plugin):
    """Answer queries using Kagi's FastGPT API.

    Config keys read by this plugin:

    - ``api_key`` (required): Kagi API token, sent as ``Authorization: Bot …``.
    - ``prompt`` (optional, default ``""``): text prepended to every query.
    - ``computing_replies`` (optional): messages, one of which is said to the
      target before the (potentially slow) API call.
    - ``timeout`` (optional, default ``DEFAULT_TIMEOUT``): seconds to wait
      for the API before giving up.

    References returned by FastGPT are registered with the ``plus`` plugin,
    when it is loaded, so that they can be requested afterwards.
    """

    BASE_URL = "https://kagi.com/api/v0/fastgpt"
    REQUIRED_CONFIGS = ["api_key"]
    # Upper bound in seconds for one FastGPT request. Without a timeout,
    # requests can wait forever on a stalled connection and block the bot.
    DEFAULT_TIMEOUT = 30
    # How many references are said when the "plus" handler fires.
    MAX_REFERENCES = 3

    def __init__(self, bot):
        super().__init__(bot)
        self.api_key = self.config["api_key"]
        self.prompt = self.config.get("prompt", "")
        self.timeout = self.config.get("timeout", self.DEFAULT_TIMEOUT)

    def on_welcome(self, _):
        """Mark the plugin as not ready when no API key is configured."""
        if not self.api_key:
            self.bot.log_w("Kagi FastGPT API key unavailable.")
            self.is_ready = False

    def reply(self, query: str, target: str):
        """Answer ``query`` to ``target`` using FastGPT.

        First says a random "computing" message (if any are configured).
        Then it says the sanitized answer and registers its references for
        the plus plugin. When no answer is obtained, ``signal_failure`` is
        called instead.
        """
        computing_replies = self.config.get("computing_replies")
        if computing_replies:
            self.bot.say(target, random.choice(computing_replies))
        output, references = self.complete(query)
        if output:
            self.bot.say(target, self.sanitize(output))
            self.register_references_for_plus(references, target)
        else:
            self.signal_failure(target)

    def complete(self, query: str) -> Tuple[Optional[str], list]:
        """Send ``query``, prefixed with the configured prompt, to FastGPT.

        Returns ``(output, references)`` taken from the response's ``data``
        object. On any failure it logs the problem and returns ``(None, [])``.
        Failures include a network error or timeout, a non-JSON body, a
        missing ``data`` object, or an empty ``output``.
        """
        try:
            response = requests.post(
                self.BASE_URL,
                headers={"Authorization": f"Bot {self.api_key}"},
                json={"query": self.prompt + query},
                timeout=self.timeout,
            )
        except requests.RequestException as exc:
            self.bot.log_e(f"Request error: {exc}")
            return None, []

        try:
            payload = response.json()
        except ValueError as exc:
            # Error pages (e.g. from a proxy or gateway) may not be JSON.
            # requests' JSONDecodeError subclasses ValueError in all versions.
            self.bot.log_e(
                f"Invalid JSON response (HTTP {response.status_code}): {exc}"
            )
            return None, []

        data = payload.get("data") if isinstance(payload, dict) else None
        if not data:
            if not response.ok:
                # Surface auth/quota/server errors instead of only guessing.
                self.bot.log_w(
                    f"Kagi FastGPT replied with HTTP {response.status_code}."
                )
            self.bot.log_w("No data at all, no more tokens available?")
            return None, []
        self.bot.log_d(f"Data received: {json.dumps(data)}")

        output = data.get("output", "")
        if not output:
            self.bot.log_w("Empty FastGPT output!")
            return None, []
        # An explicit null is normalized to an empty list to honour the
        # declared return type.
        references = data.get("references") or []
        return output, references

    def register_references_for_plus(
        self,
        references: list[dict],
        target: str
    ) -> None:
        """Let the plus plugin say the first references for ``target``.

        Each reference is expected to be a mapping with ``title`` and ``url``
        keys. Missing or null values are treated as empty, and references
        with neither are skipped. Nothing is registered when there are no
        references or the ``plus`` plugin is not available.
        """
        if not references:
            return
        plus_plugin = self.bot.get_plugin("plus")
        if not plus_plugin:
            return

        def handler(plus_event):
            for ref in references[:self.MAX_REFERENCES]:
                title = ref.get("title") or ""
                url = ref.get("url") or ""
                message = f"{title} {url}".strip()
                if message:
                    self.bot.say(plus_event.target, message)

        cast(PlusPlugin, plus_plugin).add_handler(target, handler)

    def sanitize(self, text: str) -> str:
        """Turn FastGPT output into a single, length-limited chat line.

        Each newline, together with any whitespace around it, collapses into
        one space. This keeps separate lines and paragraphs from being glued
        together ("end.Next"). The result is then passed through
        ``limit_text_length``.
        """
        text = text.strip()
        text = re.sub(r"\s*\n\s*", " ", text)
        return limit_text_length(text)