Add basic chain with memory
This commit is contained in:
parent
fe5f5e2d3a
commit
fd29965e5b
@ -8,3 +8,6 @@ nlp:
|
||||
token: "${OPENAI_TOKEN}"
|
||||
model: "gpt-3.5-turbo"
|
||||
|
||||
api:
|
||||
serpapi: "${SERPAPI_TOKEN}"
|
||||
|
||||
|
@ -4,7 +4,8 @@ from config import config
|
||||
import nlp
|
||||
import logging
|
||||
|
||||
completion = nlp.CompletionOpenAI(config["nlp.token"])
|
||||
completion = nlp.CompletionOpenAI(token=config["nlp.token"],
|
||||
serpapi=config["api.serpapi"])
|
||||
|
||||
async def parser(client, message: str):
|
||||
"""
|
||||
|
@ -21,10 +21,9 @@ async def on_ready():
|
||||
|
||||
@client.event
|
||||
async def on_message(message):
|
||||
print(message)
|
||||
logging.info(f"Message received : {message.content}")
|
||||
if message.author == client.user:
|
||||
return
|
||||
#if message.author == client.user:
|
||||
# return
|
||||
if len(message.content) == 0: # Ignore empty messages
|
||||
return
|
||||
#if message.content.startswith(config["discord.prefix"]):
|
||||
|
@ -1,9 +1,24 @@
|
||||
from langchain.agents import load_tools
|
||||
from langchain.agents import ZeroShotAgent, Tool, AgentExecutor
|
||||
from langchain.memory import ConversationBufferMemory
|
||||
from langchain.agents import initialize_agent
|
||||
from langchain.agents import AgentType
|
||||
from langchain import LLMChain
|
||||
|
||||
|
||||
class CompletionMeta(type):
|
||||
"""
|
||||
The meta class for completion interface
|
||||
"""
|
||||
|
||||
tools = None
|
||||
|
||||
class Completion():
|
||||
def __init__(self):
|
||||
self.prefix = None
|
||||
self.suffix = None
|
||||
self.promptVars = None
|
||||
|
||||
def complete(self, message: str) -> str:
|
||||
"""
|
||||
Perform a text completion using the language model
|
||||
@ -15,3 +30,52 @@ class Completion():
|
||||
Return the model name
|
||||
"""
|
||||
pass
|
||||
|
||||
def getChain(self, llm):
|
||||
pass
|
||||
|
||||
def getAgent(self, llm):
|
||||
# Load tools
|
||||
tools = load_tools(["serpapi", "llm-math"], llm=llm)
|
||||
|
||||
# Build a prompt
|
||||
prompt = ZeroShotAgent.create_prompt(
|
||||
tools=tools,
|
||||
prefix=self.getPromptPrefix(),
|
||||
suffix=self.getPromptSuffix(),
|
||||
input_variables=self.getPromptVariables()
|
||||
)
|
||||
memory = ConversationBufferMemory(memory_key="chat_history")
|
||||
|
||||
# Build the LLM Chain
|
||||
llm_chain = LLMChain(llm=llm, prompt=prompt)
|
||||
agent = ZeroShotAgent(llm_chain=llm_chain,
|
||||
tools=tools,
|
||||
verbose=True)
|
||||
self.agent_chain = AgentExecutor.from_agent_and_tools(agent=agent,
|
||||
tools=tools,
|
||||
verbose=True,
|
||||
memory=memory)
|
||||
|
||||
return self.agent_chain
|
||||
|
||||
def getPromptVariables(self):
|
||||
if self.promptVars is None:
|
||||
self.promptVars = ["input", "chat_history", "agent_scratchpad"]
|
||||
return self.promptVars
|
||||
|
||||
def getPromptPrefix(self):
|
||||
if self.prefix is None:
|
||||
self.prefix = """Have a conversation with a human, answering the following
|
||||
questions as best you can. You have access to the following tools:"""
|
||||
return self.prefix
|
||||
|
||||
def getPromptSuffix(self):
|
||||
if self.suffix is None:
|
||||
self.suffix = """Begin!
|
||||
{chat_history}
|
||||
Question: {input}
|
||||
{agent_scratchpad}"""
|
||||
return self.suffix
|
||||
|
||||
|
||||
|
@ -1,23 +1,26 @@
|
||||
import openai
|
||||
from .Completion import Completion
|
||||
from langchain.llms import OpenAI
|
||||
import os
|
||||
|
||||
class CompletionOpenAI(Completion):
|
||||
# Constructor
|
||||
def __init__(self, token: str, model: str = "gpt-3.5-turbo"):
|
||||
def __init__(self,
|
||||
token: str,
|
||||
serpapi: str,
|
||||
temperature: float = 0.5):
|
||||
super().__init__()
|
||||
|
||||
os.environ["SERPAPI_API_KEY"] = serpapi
|
||||
self.token = token
|
||||
self.model = model
|
||||
openai.api_key = self.token
|
||||
|
||||
self.temperature = temperature
|
||||
self.llm = OpenAI(temperature=self.temperature,
|
||||
openai_api_key=self.token)
|
||||
self.agent = self.getAgent(self.llm)
|
||||
|
||||
|
||||
def complete(self, message: str) -> str:
|
||||
chat_completion = openai.ChatCompletion.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": message
|
||||
}
|
||||
])
|
||||
return chat_completion.choices[0].message.content
|
||||
return self.agent_chain.run(message)
|
||||
|
||||
def getModelName(self) -> str:
|
||||
return self.model
|
||||
return ""
|
||||
|
@ -2,10 +2,33 @@ aiohttp==3.8.4
|
||||
aiosignal==1.3.1
|
||||
async-timeout==4.0.2
|
||||
attrs==23.1.0
|
||||
certifi==2023.5.7
|
||||
charset-normalizer==3.1.0
|
||||
dataclasses-json==0.5.7
|
||||
discord==2.2.3
|
||||
discord.py==2.2.3
|
||||
envyaml==1.10.211231
|
||||
frozenlist==1.3.3
|
||||
google-search-results==2.4.2
|
||||
greenlet==2.0.2
|
||||
idna==3.4
|
||||
langchain==0.0.178
|
||||
marshmallow==3.19.0
|
||||
marshmallow-enum==1.5.1
|
||||
multidict==6.0.4
|
||||
mypy-extensions==1.0.0
|
||||
numexpr==2.8.4
|
||||
numpy==1.24.3
|
||||
openai==0.27.7
|
||||
openapi-schema-pydantic==1.2.4
|
||||
packaging==23.1
|
||||
pydantic==1.10.8
|
||||
PyYAML==6.0
|
||||
requests==2.31.0
|
||||
SQLAlchemy==2.0.15
|
||||
tenacity==8.2.2
|
||||
tqdm==4.65.0
|
||||
typing-inspect==0.8.0
|
||||
typing_extensions==4.6.0
|
||||
urllib3==2.0.2
|
||||
yarl==1.9.2
|
||||
|
Loading…
Reference in New Issue
Block a user