gpt3: add plugin
This commit is contained in:
parent
84f4eaeec9
commit
d2e3b7503b
1
Pipfile
1
Pipfile
|
@ -19,3 +19,4 @@ translate = "*"
|
||||||
scaruffi = "==0.0.3"
|
scaruffi = "==0.0.3"
|
||||||
wolframalpha = "*"
|
wolframalpha = "*"
|
||||||
meteofrance-api = "~=1.0.2"
|
meteofrance-api = "~=1.0.2"
|
||||||
|
openai = "~=0.25.0"
|
||||||
|
|
|
@ -70,6 +70,8 @@
|
||||||
},
|
},
|
||||||
"gpt3": {
|
"gpt3": {
|
||||||
"openai_key": ""
|
"openai_key": ""
|
||||||
|
"join_lines": "; ",
|
||||||
|
"computing_replies": ["Hmm…"]
|
||||||
},
|
},
|
||||||
"horoscope": {
|
"horoscope": {
|
||||||
"commands": ["horoscope"],
|
"commands": ["horoscope"],
|
||||||
|
|
48
edmond/plugins/gpt3.py
Normal file
48
edmond/plugins/gpt3.py
Normal file
|
@ -0,0 +1,48 @@
|
||||||
|
import random
|
||||||
|
import re
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import openai
|
||||||
|
|
||||||
|
from edmond.plugin import Plugin
|
||||||
|
from edmond.utils import limit_text_length
|
||||||
|
|
||||||
|
|
||||||
|
class Gpt3Plugin(Plugin):
|
||||||
|
|
||||||
|
REQUIRED_CONFIGS = ["openai_key", "computing_replies", "join_lines"]
|
||||||
|
|
||||||
|
def __init__(self, bot):
|
||||||
|
super().__init__(bot)
|
||||||
|
openai.api_key = self.config["openai_key"]
|
||||||
|
|
||||||
|
def reply(self, prompt: str, target: str):
|
||||||
|
computing_reply = random.choice(self.config["computing_replies"])
|
||||||
|
self.bot.say(target, computing_reply)
|
||||||
|
|
||||||
|
completion = self.complete(prompt)
|
||||||
|
if completion and (reply := self.sanitize(completion)):
|
||||||
|
self.bot.say(target, reply)
|
||||||
|
else:
|
||||||
|
self.bot.signal_failure(target)
|
||||||
|
|
||||||
|
def complete(self, prompt: str) -> Optional[str]:
|
||||||
|
try:
|
||||||
|
completion = openai.Completion.create(
|
||||||
|
model="text-davinci-002",
|
||||||
|
prompt=prompt,
|
||||||
|
temperature=0.7,
|
||||||
|
max_tokens=128,
|
||||||
|
top_p=1,
|
||||||
|
frequency_penalty=0.5,
|
||||||
|
presence_penalty=0
|
||||||
|
)
|
||||||
|
except openai.error.OpenAIError:
|
||||||
|
return None
|
||||||
|
return completion.choices[0].text
|
||||||
|
|
||||||
|
def sanitize(self, text: str) -> str:
|
||||||
|
text = text.strip()
|
||||||
|
text = re.sub(r"\n+", self.config["join_lines"], text)
|
||||||
|
text = limit_text_length(text)
|
||||||
|
return text
|
|
@ -5,6 +5,7 @@ import wikipedia
|
||||||
|
|
||||||
from edmond.plugin import Plugin
|
from edmond.plugin import Plugin
|
||||||
from edmond.plugins.plus import PlusPlugin
|
from edmond.plugins.plus import PlusPlugin
|
||||||
|
from edmond.utils import limit_text_length
|
||||||
|
|
||||||
|
|
||||||
class WikipediaPlugin(Plugin):
|
class WikipediaPlugin(Plugin):
|
||||||
|
@ -47,7 +48,7 @@ class WikipediaPlugin(Plugin):
|
||||||
self.bot.log_d(f"Wikipedia exception: {exc}")
|
self.bot.log_d(f"Wikipedia exception: {exc}")
|
||||||
retries -= 1
|
retries -= 1
|
||||||
if page:
|
if page:
|
||||||
reply = WikipediaPlugin.limit_text_length(page.summary)
|
reply = limit_text_length(page.summary)
|
||||||
self.register_url_for_plus(page.url, event.target)
|
self.register_url_for_plus(page.url, event.target)
|
||||||
self.bot.say(event.target, reply)
|
self.bot.say(event.target, reply)
|
||||||
|
|
||||||
|
@ -71,7 +72,7 @@ class WikipediaPlugin(Plugin):
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
retries -= 1
|
retries -= 1
|
||||||
if page:
|
if page:
|
||||||
reply = WikipediaPlugin.limit_text_length(page.summary)
|
reply = limit_text_length(page.summary)
|
||||||
self.register_url_for_plus(page.url, event.target)
|
self.register_url_for_plus(page.url, event.target)
|
||||||
self.bot.say(event.target, reply)
|
self.bot.say(event.target, reply)
|
||||||
|
|
||||||
|
@ -80,19 +81,3 @@ class WikipediaPlugin(Plugin):
|
||||||
def handler(plus_event):
|
def handler(plus_event):
|
||||||
self.bot.say(plus_event.target, url)
|
self.bot.say(plus_event.target, url)
|
||||||
cast(PlusPlugin, plus_plugin).add_handler(target, handler)
|
cast(PlusPlugin, plus_plugin).add_handler(target, handler)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def limit_text_length(text, max_length=200):
|
|
||||||
"""Limit text size to 200 characters max."""
|
|
||||||
words = text.split(" ")
|
|
||||||
cut_text = ""
|
|
||||||
while words:
|
|
||||||
next_word = words.pop(0)
|
|
||||||
if len(cut_text) + len(next_word) + 1 >= max_length:
|
|
||||||
break
|
|
||||||
cut_text += next_word + " "
|
|
||||||
if len(cut_text) < len(text):
|
|
||||||
cut_text = cut_text[:-1] + "…"
|
|
||||||
else:
|
|
||||||
cut_text = cut_text.rstrip()
|
|
||||||
return cut_text
|
|
||||||
|
|
|
@ -13,3 +13,19 @@ def http_get(url: str) -> Optional[str]:
|
||||||
|
|
||||||
def proc(proba_percentage: int) -> bool:
|
def proc(proba_percentage: int) -> bool:
|
||||||
return random.random() < (proba_percentage / 100.0)
|
return random.random() < (proba_percentage / 100.0)
|
||||||
|
|
||||||
|
|
||||||
|
def limit_text_length(text, max_length=400):
|
||||||
|
"""Limit text size to 400 characters max."""
|
||||||
|
words = text.split(" ")
|
||||||
|
cut_text = ""
|
||||||
|
while words:
|
||||||
|
next_word = words.pop(0)
|
||||||
|
if len(cut_text) + len(next_word) + 1 >= max_length:
|
||||||
|
break
|
||||||
|
cut_text += next_word + " "
|
||||||
|
if len(cut_text) < len(text):
|
||||||
|
cut_text = cut_text[:-1] + "…"
|
||||||
|
else:
|
||||||
|
cut_text = cut_text.rstrip()
|
||||||
|
return cut_text
|
||||||
|
|
Loading…
Reference in a new issue