gpt3: add plugin

This commit is contained in:
dece 2022-11-29 12:53:47 +01:00
parent 84f4eaeec9
commit d2e3b7503b
5 changed files with 70 additions and 18 deletions

View file

@ -19,3 +19,4 @@ translate = "*"
scaruffi = "==0.0.3" scaruffi = "==0.0.3"
wolframalpha = "*" wolframalpha = "*"
meteofrance-api = "~=1.0.2" meteofrance-api = "~=1.0.2"
openai = "~=0.25.0"

View file

@ -70,6 +70,8 @@
}, },
"gpt3": { "gpt3": {
"openai_key": "" "openai_key": ""
"join_lines": "; ",
"computing_replies": ["Hmm…"]
}, },
"horoscope": { "horoscope": {
"commands": ["horoscope"], "commands": ["horoscope"],

48
edmond/plugins/gpt3.py Normal file
View file

@ -0,0 +1,48 @@
import random
import re
from typing import Optional
import openai
from edmond.plugin import Plugin
from edmond.utils import limit_text_length
class Gpt3Plugin(Plugin):
REQUIRED_CONFIGS = ["openai_key", "computing_replies", "join_lines"]
def __init__(self, bot):
super().__init__(bot)
openai.api_key = self.config["openai_key"]
def reply(self, prompt: str, target: str):
computing_reply = random.choice(self.config["computing_replies"])
self.bot.say(target, computing_reply)
completion = self.complete(prompt)
if completion and (reply := self.sanitize(completion)):
self.bot.say(target, reply)
else:
self.bot.signal_failure(target)
def complete(self, prompt: str) -> Optional[str]:
try:
completion = openai.Completion.create(
model="text-davinci-002",
prompt=prompt,
temperature=0.7,
max_tokens=128,
top_p=1,
frequency_penalty=0.5,
presence_penalty=0
)
except openai.error.OpenAIError:
return None
return completion.choices[0].text
def sanitize(self, text: str) -> str:
text = text.strip()
text = re.sub(r"\n+", self.config["join_lines"], text)
text = limit_text_length(text)
return text

View file

@ -5,6 +5,7 @@ import wikipedia
from edmond.plugin import Plugin from edmond.plugin import Plugin
from edmond.plugins.plus import PlusPlugin from edmond.plugins.plus import PlusPlugin
from edmond.utils import limit_text_length
class WikipediaPlugin(Plugin): class WikipediaPlugin(Plugin):
@ -47,7 +48,7 @@ class WikipediaPlugin(Plugin):
self.bot.log_d(f"Wikipedia exception: {exc}") self.bot.log_d(f"Wikipedia exception: {exc}")
retries -= 1 retries -= 1
if page: if page:
reply = WikipediaPlugin.limit_text_length(page.summary) reply = limit_text_length(page.summary)
self.register_url_for_plus(page.url, event.target) self.register_url_for_plus(page.url, event.target)
self.bot.say(event.target, reply) self.bot.say(event.target, reply)
@ -71,7 +72,7 @@ class WikipediaPlugin(Plugin):
time.sleep(1) time.sleep(1)
retries -= 1 retries -= 1
if page: if page:
reply = WikipediaPlugin.limit_text_length(page.summary) reply = limit_text_length(page.summary)
self.register_url_for_plus(page.url, event.target) self.register_url_for_plus(page.url, event.target)
self.bot.say(event.target, reply) self.bot.say(event.target, reply)
@ -80,19 +81,3 @@ class WikipediaPlugin(Plugin):
def handler(plus_event): def handler(plus_event):
self.bot.say(plus_event.target, url) self.bot.say(plus_event.target, url)
cast(PlusPlugin, plus_plugin).add_handler(target, handler) cast(PlusPlugin, plus_plugin).add_handler(target, handler)
@staticmethod
def limit_text_length(text, max_length=200):
"""Limit text size to 200 characters max."""
words = text.split(" ")
cut_text = ""
while words:
next_word = words.pop(0)
if len(cut_text) + len(next_word) + 1 >= max_length:
break
cut_text += next_word + " "
if len(cut_text) < len(text):
cut_text = cut_text[:-1] + ""
else:
cut_text = cut_text.rstrip()
return cut_text

View file

@ -13,3 +13,19 @@ def http_get(url: str) -> Optional[str]:
def proc(proba_percentage: int) -> bool: def proc(proba_percentage: int) -> bool:
return random.random() < (proba_percentage / 100.0) return random.random() < (proba_percentage / 100.0)
def limit_text_length(text, max_length=400):
"""Limit text size to 400 characters max."""
words = text.split(" ")
cut_text = ""
while words:
next_word = words.pop(0)
if len(cut_text) + len(next_word) + 1 >= max_length:
break
cut_text += next_word + " "
if len(cut_text) < len(text):
cut_text = cut_text[:-1] + ""
else:
cut_text = cut_text.rstrip()
return cut_text