Add basic chain with memory
This commit is contained in:
parent
fe5f5e2d3a
commit
fd29965e5b
|
@ -8,3 +8,6 @@ nlp:
|
||||||
token: "${OPENAI_TOKEN}"
|
token: "${OPENAI_TOKEN}"
|
||||||
model: "gpt-3.5-turbo"
|
model: "gpt-3.5-turbo"
|
||||||
|
|
||||||
|
api:
|
||||||
|
serpapi: "${SERPAPI_TOKEN}"
|
||||||
|
|
||||||
|
|
|
@ -4,7 +4,8 @@ from config import config
|
||||||
import nlp
|
import nlp
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
completion = nlp.CompletionOpenAI(config["nlp.token"])
|
completion = nlp.CompletionOpenAI(token=config["nlp.token"],
|
||||||
|
serpapi=config["api.serpapi"])
|
||||||
|
|
||||||
async def parser(client, message: str):
|
async def parser(client, message: str):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -21,10 +21,9 @@ async def on_ready():
|
||||||
|
|
||||||
@client.event
|
@client.event
|
||||||
async def on_message(message):
|
async def on_message(message):
|
||||||
print(message)
|
|
||||||
logging.info(f"Message received : {message.content}")
|
logging.info(f"Message received : {message.content}")
|
||||||
if message.author == client.user:
|
#if message.author == client.user:
|
||||||
return
|
# return
|
||||||
if len(message.content) == 0: # Ignore empty messages
|
if len(message.content) == 0: # Ignore empty messages
|
||||||
return
|
return
|
||||||
#if message.content.startswith(config["discord.prefix"]):
|
#if message.content.startswith(config["discord.prefix"]):
|
||||||
|
|
|
@ -1,9 +1,24 @@
|
||||||
|
from langchain.agents import load_tools
|
||||||
|
from langchain.agents import ZeroShotAgent, Tool, AgentExecutor
|
||||||
|
from langchain.memory import ConversationBufferMemory
|
||||||
|
from langchain.agents import initialize_agent
|
||||||
|
from langchain.agents import AgentType
|
||||||
|
from langchain import LLMChain
|
||||||
|
|
||||||
|
|
||||||
class CompletionMeta(type):
|
class CompletionMeta(type):
|
||||||
"""
|
"""
|
||||||
The meta class for completion interface
|
The meta class for completion interface
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
tools = None
|
||||||
|
|
||||||
class Completion():
|
class Completion():
|
||||||
|
def __init__(self):
|
||||||
|
self.prefix = None
|
||||||
|
self.suffix = None
|
||||||
|
self.promptVars = None
|
||||||
|
|
||||||
def complete(self, message: str) -> str:
|
def complete(self, message: str) -> str:
|
||||||
"""
|
"""
|
||||||
Perform a text completion using the language model
|
Perform a text completion using the language model
|
||||||
|
@ -15,3 +30,52 @@ class Completion():
|
||||||
Return the model name
|
Return the model name
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def getChain(self, llm):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def getAgent(self, llm):
|
||||||
|
# Load tools
|
||||||
|
tools = load_tools(["serpapi", "llm-math"], llm=llm)
|
||||||
|
|
||||||
|
# Build a prompt
|
||||||
|
prompt = ZeroShotAgent.create_prompt(
|
||||||
|
tools=tools,
|
||||||
|
prefix=self.getPromptPrefix(),
|
||||||
|
suffix=self.getPromptSuffix(),
|
||||||
|
input_variables=self.getPromptVariables()
|
||||||
|
)
|
||||||
|
memory = ConversationBufferMemory(memory_key="chat_history")
|
||||||
|
|
||||||
|
# Build the LLM Chain
|
||||||
|
llm_chain = LLMChain(llm=llm, prompt=prompt)
|
||||||
|
agent = ZeroShotAgent(llm_chain=llm_chain,
|
||||||
|
tools=tools,
|
||||||
|
verbose=True)
|
||||||
|
self.agent_chain = AgentExecutor.from_agent_and_tools(agent=agent,
|
||||||
|
tools=tools,
|
||||||
|
verbose=True,
|
||||||
|
memory=memory)
|
||||||
|
|
||||||
|
return self.agent_chain
|
||||||
|
|
||||||
|
def getPromptVariables(self):
|
||||||
|
if self.promptVars is None:
|
||||||
|
self.promptVars = ["input", "chat_history", "agent_scratchpad"]
|
||||||
|
return self.promptVars
|
||||||
|
|
||||||
|
def getPromptPrefix(self):
|
||||||
|
if self.prefix is None:
|
||||||
|
self.prefix = """Have a conversation with a human, answering the following
|
||||||
|
questions as best you can. You have access to the following tools:"""
|
||||||
|
return self.prefix
|
||||||
|
|
||||||
|
def getPromptSuffix(self):
|
||||||
|
if self.suffix is None:
|
||||||
|
self.suffix = """Begin!
|
||||||
|
{chat_history}
|
||||||
|
Question: {input}
|
||||||
|
{agent_scratchpad}"""
|
||||||
|
return self.suffix
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,23 +1,26 @@
|
||||||
import openai
|
|
||||||
from .Completion import Completion
|
from .Completion import Completion
|
||||||
|
from langchain.llms import OpenAI
|
||||||
|
import os
|
||||||
|
|
||||||
class CompletionOpenAI(Completion):
|
class CompletionOpenAI(Completion):
|
||||||
# Constructor
|
# Constructor
|
||||||
def __init__(self, token: str, model: str = "gpt-3.5-turbo"):
|
def __init__(self,
|
||||||
|
token: str,
|
||||||
|
serpapi: str,
|
||||||
|
temperature: float = 0.5):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
os.environ["SERPAPI_API_KEY"] = serpapi
|
||||||
self.token = token
|
self.token = token
|
||||||
self.model = model
|
|
||||||
openai.api_key = self.token
|
self.temperature = temperature
|
||||||
|
self.llm = OpenAI(temperature=self.temperature,
|
||||||
|
openai_api_key=self.token)
|
||||||
|
self.agent = self.getAgent(self.llm)
|
||||||
|
|
||||||
|
|
||||||
def complete(self, message: str) -> str:
|
def complete(self, message: str) -> str:
|
||||||
chat_completion = openai.ChatCompletion.create(
|
return self.agent_chain.run(message)
|
||||||
model=self.model,
|
|
||||||
messages=[
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": message
|
|
||||||
}
|
|
||||||
])
|
|
||||||
return chat_completion.choices[0].message.content
|
|
||||||
|
|
||||||
def getModelName(self) -> str:
|
def getModelName(self) -> str:
|
||||||
return self.model
|
return ""
|
||||||
|
|
|
@ -2,10 +2,33 @@ aiohttp==3.8.4
|
||||||
aiosignal==1.3.1
|
aiosignal==1.3.1
|
||||||
async-timeout==4.0.2
|
async-timeout==4.0.2
|
||||||
attrs==23.1.0
|
attrs==23.1.0
|
||||||
|
certifi==2023.5.7
|
||||||
charset-normalizer==3.1.0
|
charset-normalizer==3.1.0
|
||||||
|
dataclasses-json==0.5.7
|
||||||
discord==2.2.3
|
discord==2.2.3
|
||||||
discord.py==2.2.3
|
discord.py==2.2.3
|
||||||
|
envyaml==1.10.211231
|
||||||
frozenlist==1.3.3
|
frozenlist==1.3.3
|
||||||
|
google-search-results==2.4.2
|
||||||
|
greenlet==2.0.2
|
||||||
idna==3.4
|
idna==3.4
|
||||||
|
langchain==0.0.178
|
||||||
|
marshmallow==3.19.0
|
||||||
|
marshmallow-enum==1.5.1
|
||||||
multidict==6.0.4
|
multidict==6.0.4
|
||||||
|
mypy-extensions==1.0.0
|
||||||
|
numexpr==2.8.4
|
||||||
|
numpy==1.24.3
|
||||||
|
openai==0.27.7
|
||||||
|
openapi-schema-pydantic==1.2.4
|
||||||
|
packaging==23.1
|
||||||
|
pydantic==1.10.8
|
||||||
|
PyYAML==6.0
|
||||||
|
requests==2.31.0
|
||||||
|
SQLAlchemy==2.0.15
|
||||||
|
tenacity==8.2.2
|
||||||
|
tqdm==4.65.0
|
||||||
|
typing-inspect==0.8.0
|
||||||
|
typing_extensions==4.6.0
|
||||||
|
urllib3==2.0.2
|
||||||
yarl==1.9.2
|
yarl==1.9.2
|
||||||
|
|
Loading…
Reference in New Issue