Add basic chain with memory

This commit is contained in:
Tyler Perkins 2023-05-24 00:06:53 -04:00
parent fe5f5e2d3a
commit fd29965e5b
6 changed files with 111 additions and 18 deletions

View File

@ -8,3 +8,6 @@ nlp:
token: "${OPENAI_TOKEN}"
model: "gpt-3.5-turbo"
api:
serpapi: "${SERPAPI_TOKEN}"

View File

@ -4,7 +4,8 @@ from config import config
import nlp
import logging
completion = nlp.CompletionOpenAI(config["nlp.token"])
completion = nlp.CompletionOpenAI(token=config["nlp.token"],
serpapi=config["api.serpapi"])
async def parser(client, message: str):
"""

View File

@ -21,10 +21,9 @@ async def on_ready():
@client.event
async def on_message(message):
print(message)
logging.info(f"Message received : {message.content}")
if message.author == client.user:
return
#if message.author == client.user:
# return
if len(message.content) == 0: # Ignore empty messages
return
#if message.content.startswith(config["discord.prefix"]):

View File

@ -1,9 +1,24 @@
from langchain.agents import load_tools
from langchain.agents import ZeroShotAgent, Tool, AgentExecutor
from langchain.memory import ConversationBufferMemory
from langchain.agents import initialize_agent
from langchain.agents import AgentType
from langchain import LLMChain
class CompletionMeta(type):
"""
The meta class for completion interface
"""
tools = None
class Completion():
def __init__(self):
self.prefix = None
self.suffix = None
self.promptVars = None
def complete(self, message: str) -> str:
"""
Perform a text completion using the language model
@ -15,3 +30,52 @@ class Completion():
Return the model name
"""
pass
def getChain(self, llm):
pass
def getAgent(self, llm):
# Load tools
tools = load_tools(["serpapi", "llm-math"], llm=llm)
# Build a prompt
prompt = ZeroShotAgent.create_prompt(
tools=tools,
prefix=self.getPromptPrefix(),
suffix=self.getPromptSuffix(),
input_variables=self.getPromptVariables()
)
memory = ConversationBufferMemory(memory_key="chat_history")
# Build the LLM Chain
llm_chain = LLMChain(llm=llm, prompt=prompt)
agent = ZeroShotAgent(llm_chain=llm_chain,
tools=tools,
verbose=True)
self.agent_chain = AgentExecutor.from_agent_and_tools(agent=agent,
tools=tools,
verbose=True,
memory=memory)
return self.agent_chain
def getPromptVariables(self):
if self.promptVars is None:
self.promptVars = ["input", "chat_history", "agent_scratchpad"]
return self.promptVars
def getPromptPrefix(self):
if self.prefix is None:
self.prefix = """Have a conversation with a human, answering the following
questions as best you can. You have access to the following tools:"""
return self.prefix
def getPromptSuffix(self):
if self.suffix is None:
self.suffix = """Begin!
{chat_history}
Question: {input}
{agent_scratchpad}"""
return self.suffix

View File

@ -1,23 +1,26 @@
import openai
from .Completion import Completion
from langchain.llms import OpenAI
import os
class CompletionOpenAI(Completion):
# Constructor
def __init__(self, token: str, model: str = "gpt-3.5-turbo"):
def __init__(self,
token: str,
serpapi: str,
temperature: float = 0.5):
super().__init__()
os.environ["SERPAPI_API_KEY"] = serpapi
self.token = token
self.model = model
openai.api_key = self.token
self.temperature = temperature
self.llm = OpenAI(temperature=self.temperature,
openai_api_key=self.token)
self.agent = self.getAgent(self.llm)
def complete(self, message: str) -> str:
chat_completion = openai.ChatCompletion.create(
model=self.model,
messages=[
{
"role": "user",
"content": message
}
])
return chat_completion.choices[0].message.content
return self.agent_chain.run(message)
def getModelName(self) -> str:
return self.model
return ""

View File

@ -2,10 +2,33 @@ aiohttp==3.8.4
aiosignal==1.3.1
async-timeout==4.0.2
attrs==23.1.0
certifi==2023.5.7
charset-normalizer==3.1.0
dataclasses-json==0.5.7
discord==2.2.3
discord.py==2.2.3
envyaml==1.10.211231
frozenlist==1.3.3
google-search-results==2.4.2
greenlet==2.0.2
idna==3.4
langchain==0.0.178
marshmallow==3.19.0
marshmallow-enum==1.5.1
multidict==6.0.4
mypy-extensions==1.0.0
numexpr==2.8.4
numpy==1.24.3
openai==0.27.7
openapi-schema-pydantic==1.2.4
packaging==23.1
pydantic==1.10.8
PyYAML==6.0
requests==2.31.0
SQLAlchemy==2.0.15
tenacity==8.2.2
tqdm==4.65.0
typing-inspect==0.8.0
typing_extensions==4.6.0
urllib3==2.0.2
yarl==1.9.2