Add "add document" endpoint

This commit is contained in:
Tyler Perkins 2023-06-12 22:30:39 -04:00
parent f9f1982442
commit 084e1cd383
12 changed files with 353 additions and 32 deletions

View File

@ -2,6 +2,8 @@ import envyaml
import os import os
import logging import logging
from logging import Formatter, StreamHandler from logging import Formatter, StreamHandler
import weaviate
import json
config = envyaml.EnvYAML(os.environ.get('CONFIG_PATH', 'config.yaml')) config = envyaml.EnvYAML(os.environ.get('CONFIG_PATH', 'config.yaml'))
@ -35,3 +37,15 @@ def initLogger() -> None:
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
logging.warning("Invalid log level. Using INFO as default") logging.warning("Invalid log level. Using INFO as default")
def initEnvironment() -> None:
    """Export the OpenAI connection settings from the app config into the process environment."""
    openai_env = {
        "OPENAI_PROXY": config["app.openai.url"],
        "OPENAI_API_KEY": config["app.openai.api_key"],
    }
    os.environ.update(openai_env)
def initWeviate() -> None:
    """Connect to Weaviate and create the document schema if it does not exist yet.

    NOTE(review): the "Weviate" spelling is kept because callers import this name.
    """
    logging.debug("Initializing Weaviate")
    weaviate_client = weaviate.Client(config["app.weaviate.url"])
    with open(config["app.weaviate.schema-path"]) as schema_file:
        desired_schema = json.load(schema_file)
    # Creating an already-existing schema would raise, so check first.
    if weaviate_client.schema.contains(desired_schema):
        return
    logging.debug("Creating Weaviate schema")
    weaviate_client.schema.create(desired_schema)

View File

@ -1,6 +1,18 @@
weaviate:
url: "localhost:"
app: app:
log: log:
level: "" level: "debug"
weaviate:
url: "http://localhost:12345"
text-field: "content"
index-name: "knowledge"
schema-path: "schema.json"
openai:
url: "http://192.168.1.104:11111"
api_key: "sk-"
document:
# How many characters should each chunked document be split into?
split_chunk_size: 500
# How much overlap should each chunk have with its neighbor
split_chunk_overlap: 20
# Which model in our OpenAI API should we use?
embeddings_model: "text-embedding"

View File

@ -0,0 +1,12 @@
from enum import Enum
class DocumentType(str, Enum):
    """Document formats the ingestion pipeline supports, keyed by file extension."""

    markdown = "md"
    # Formats below are planned but not yet wired into a loader:
    # html = "html"
    # pdf = "pdf"
    # epub = "epub"
    # odt = "odt"
    # docx = "docx"

View File

@ -1,3 +1,71 @@
from langchain.document_loaders import UnstructuredMarkdownLoader
from langchain.schema import Document
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Weaviate
from .documentTypes import DocumentType
from config import config
import logging
import weaviate
from typing import Iterable, List
from langchain.text_splitter import NLTKTextSplitter
from internal.weaviate.weaviate import WeaviateClient
from langchain.text_splitter import SpacyTextSplitter # Globals
# =======
# The text splitter
# Shared NLTK-based splitter; chunk size (characters) and overlap between
# neighboring chunks come from the app.document.* config keys.
textSplitter = NLTKTextSplitter(chunk_size=config["app.document.split_chunk_size"],
                                chunk_overlap=config["app.document.split_chunk_overlap"])
"""
The text splitter client
"""
# Lazily-initialized WeaviateClient shared by this module; created on the
# first call to loadDocumentIntoWeaviate so no connection is made at import.
db = None
def loadDocumentIntoWeaviate(documentType: DocumentType,
                             document_name: str,
                             path: str) -> None:
    """
    Load the file at *path* into Weaviate as a set of chunked documents.

    Args:
        documentType: Format of the file; only markdown is supported so far.
        document_name: User-facing identifier stamped on every stored chunk.
        path: Filesystem path of the file to ingest.

    Raises:
        Exception: If the document type is not supported.
    """
    global db
    if documentType == DocumentType.markdown:
        loader = UnstructuredMarkdownLoader(path)
        documents = loader.load()
    else:
        raise Exception("Document type not supported.")
    # Split the loaded document(s) and tag every chunk with its source name
    texts = splitDocument(documents)
    for text in texts:
        text.metadata["document_name"] = document_name
    # Lazily create the shared Weaviate client on first use
    if db is None:
        # was: print(...) — use the logger instead of writing to stdout
        logging.debug("Connecting to Weaviate at %s", config["app.weaviate.url"])
        client = weaviate.Client(config["app.weaviate.url"])
        db = WeaviateClient(client,
                            text_key=config["app.weaviate.text-field"],
                            index_name=config["app.weaviate.index-name"])
    db.addDocuments(documents=texts)
    # Log the document name and the number of chunks actually stored
    logging.info(f"Loaded document {document_name} ({len(texts)} chunks) into Weaviate")
def splitDocument(document: Document | Iterable[Document]) -> List[Document]:
    """
    Split one or more documents into chunks using the module-level NLTK splitter.

    (Docstring previously said spaCy; the splitter actually used is NLTKTextSplitter.)

    Args:
        document: The document(s) to split.
            NOTE(review): split_documents expects an iterable — a bare Document
            passed without a list will likely fail; confirm callers always
            pass a sequence (the current caller passes loader.load()).

    Returns:
        The chunked documents, sized per the app.document.* config.

    Raises:
        ValueError: If *document* is None.
    """
    if document is None:
        # ValueError subclasses Exception, so existing `except Exception` callers still work
        raise ValueError("Document is None")
    # Reading the module global needs no `global` declaration
    return textSplitter.split_documents(document)

View File

@ -0,0 +1,11 @@
from abc import ABC, abstractmethod
from typing import List
class EmbeddingProvider(ABC):
    """Interface for anything that can turn a piece of text into an embedding vector."""

    @abstractmethod
    def getEmbedding(self, word: str) -> List[float]:
        """
        Return the embedding vector for the given word.
        """
        ...

View File

@ -0,0 +1,48 @@
from .EmbeddingProvider import EmbeddingProvider
from typing import List
import openai
from config import config
class OpenAIEmbeddingProvider(EmbeddingProvider):
    """EmbeddingProvider backed by the OpenAI embeddings endpoint.

    The API key, base URL, and model name are read from the application
    config at construction time and can be overridden afterwards through
    the uppercase property setters.
    """

    def __init__(self):
        super().__init__()
        self.openai_api_key = config["app.openai.api_key"]
        self.openai_url = config["app.openai.url"]
        self.model = config["app.document.embeddings_model"]

    def getEmbedding(self, text: str) -> List[float]:
        """
        Return the embedding vector for the given string.
        """
        # The legacy openai SDK takes connection settings as module globals,
        # so they are (re)applied before every call.
        openai.api_key = self.openai_api_key
        openai.api_base = self.openai_url
        response = openai.Embedding.create(input=[text], model=self.model)
        return response['data'][0]['embedding']

    @property
    def OPENAI_API_KEY(self):
        return self.openai_api_key

    @OPENAI_API_KEY.setter
    def OPENAI_API_KEY(self, value):
        self.openai_api_key = value

    @property
    def OPENAI_URL(self):
        return self.openai_url

    @OPENAI_URL.setter
    def OPENAI_URL(self, value):
        self.openai_url = value

    @property
    def MODEL(self):
        return self.model

    @MODEL.setter
    def MODEL(self, value):
        self.model = value

View File

View File

@ -0,0 +1,121 @@
import weaviate
import json
import logging
from typing import Iterable, List
from internal.embeddings.OpenAIEmbeddingProvider import OpenAIEmbeddingProvider
from internal.embeddings.EmbeddingProvider import EmbeddingProvider
from langchain.schema import Document
class WeaviateClient:
    """Thin wrapper around a weaviate.Client that stores langchain Documents
    together with externally-computed embedding vectors.
    """

    # Constructor
    def __init__(self, client: weaviate.Client,
                 text_key: str = "content",
                 index_name: str = "documents",
                 embeddingProvider: EmbeddingProvider = None,
                 ):
        """
        Args:
            client: Connected weaviate.Client used for all operations.
            text_key: Weaviate property the document text is stored under.
            index_name: Weaviate class the documents are stored in.
            embeddingProvider: Source of embedding vectors. Defaults to a
                fresh OpenAIEmbeddingProvider per instance. (Previously the
                default was built in the signature, i.e. once at import time
                and shared by every instance — and it read the config even
                when the module was merely imported.)
        """
        self.client = client
        # Build the default lazily, per instance, not at module import time.
        if embeddingProvider is None:
            embeddingProvider = OpenAIEmbeddingProvider()
        self.embeddingProvider = embeddingProvider
        self.text_key = text_key
        self.index_name = index_name
        # Previously never initialized, so reading ENDPOINT before setting it
        # raised AttributeError; fromUrl() now records the endpoint it used.
        self.endpoint = None

    @classmethod
    def fromUrl(cls, endpoint: str):
        """
        Creates a WeaviateClient from an endpoint
        """
        client = weaviate.Client(endpoint)
        instance = cls(client)
        instance.endpoint = endpoint
        return instance

    def addDocuments(self, documents: Iterable[Document]) -> None:
        """
        Adds a list of documents to the store

        Each document's text is stored under self.text_key, its metadata is
        copied onto the object, and its vector comes from the embedding
        provider (the schema uses vectorizer "none").

        Raises:
            Exception: If no embedding provider is set.
        """
        if self.embeddingProvider is None:
            raise Exception("No embedding provider set")
        # Batch context flushes the queued objects on exit
        with self.client.batch as batch:
            for i, text in enumerate(documents):
                data_properties = {self.text_key: text.page_content}
                if text.metadata is not None:
                    for key, val in text.metadata.items():
                        data_properties[key] = val
                vector = self.embeddingProvider.getEmbedding(text.page_content)
                batch.add_data_object(
                    data_object=data_properties,
                    class_name=self.index_name,
                    vector=vector
                )

    def addDocument(self, document: Document) -> None:
        """
        Adds a given document to the store
        """
        self.addDocuments([document])

    def similaritySearch(self, query: str, k: int = 10) -> List[str]:
        """
        Searches for similar documents

        Args:
            query: Text to lookup
            k: Number of results to return, default 10
        """
        # TODO: not implemented yet
        pass

    def removeDocument(self, document_id: str) -> None:
        """
        Removes a document from the store

        Args:
            document_id: Id of the document to remove
        """
        # TODO: not implemented yet
        pass

    def getDocument(self, document_id: str) -> str:
        """
        Returns a document from the store

        Args:
            document_id: Id of the document to return
        """
        # TODO: not implemented yet
        pass

    @property
    def ENDPOINT(self):
        return self.endpoint

    @ENDPOINT.setter
    def ENDPOINT(self, endpoint: str):
        self.endpoint = endpoint

    @property
    def EMBEDDING_PROVIDER(self):
        return self.embeddingProvider

    @EMBEDDING_PROVIDER.setter
    def EMBEDDING_PROVIDER(self, embeddingProvider: EmbeddingProvider):
        self.embeddingProvider = embeddingProvider

    @property
    def TEXT_KEY(self):
        return self.text_key

    @TEXT_KEY.setter
    def TEXT_KEY(self, text_key: str):
        self.text_key = text_key

    @property
    def INDEX_NAME(self):
        return self.index_name

    @INDEX_NAME.setter
    def INDEX_NAME(self, index_name: str):
        self.index_name = index_name

View File

@ -4,18 +4,19 @@ from fastapi.responses import RedirectResponse
from routers import question, documents from routers import question, documents
import logging import logging
from config import config, initLogger from config import config, initLogger, initEnvironment, initWeviate
initEnvironment()
initLogger() initLogger()
# Init weaviate, if not done already
initWeviate()
app = FastAPI() app = FastAPI()
app.include_router(question.router) app.include_router(question.router)
app.include_router(documents.router) app.include_router(documents.router)
logging.warn("Test message")
@app.get("/") @app.get("/")
async def root(): async def root():
""" """

View File

@ -1,6 +1,11 @@
from fastapi import APIRouter from fastapi import APIRouter, UploadFile
from pydantic import BaseModel from pydantic import BaseModel
from enum import Enum from enum import Enum
import tempfile
import logging
from internal.documentTypes import DocumentType
from internal import documents
router = APIRouter( router = APIRouter(
prefix="/documents", prefix="/documents",
@ -8,21 +13,6 @@ router = APIRouter(
responses={404: {"description": "Not found"}}, responses={404: {"description": "Not found"}},
) )
class Document(BaseModel):
id: int
title: str
content: str
# Document Type enum
class DocumentType(str, Enum):
markdown = "md"
#html = "html"
#pdf = "pdf"
epub = "epub"
odt = "odt"
docx = "docx"
@router.get("/{document_type}") @router.get("/{document_type}")
async def read_documents(document_type: DocumentType): async def read_documents(document_type: DocumentType):
""" """
@ -31,30 +21,54 @@ async def read_documents(document_type: DocumentType):
pass pass
@router.get("/{document_id}") @router.get("/{document_id}")
async def read_document(document_id: int): async def read_document(document_id: str):
""" """
Get a specific document Get a specific document
""" """
pass pass
@router.post("/") @router.post("/")
async def create_document(document: Document): async def create_document(document_type: DocumentType,
file: UploadFile):
""" """
Create a new document Create a new document
""" """
pass tmp = tempfile.NamedTemporaryFile(delete=True)
document_id = file.filename
try:
logging.info(f"Uploaded file {file.filename} to {tmp.name}")
# Write the file to a temporary file
tmp.write(await file.read())
tmp.flush()
# Load the document
documents.loadDocumentIntoWeaviate(document_type, document_id, tmp.name)
finally:
tmp.close()
return {"document_id": document_id}
@router.put("/{document_id}") @router.put("/{document_id}")
async def update_document(document_id: int, document: Document): async def update_document(document_id: str):
""" """
Update a document Update a document
""" """
pass
@router.delete("/{document_id}") @router.delete("/{document_id}")
async def delete_document(document_id: int): async def delete_document(document_id: str):
""" """
Delete a document Delete a document
""" """
pass pass
@router.get("/find")
async def find_document(query: str):
"""
Finds a document with content similar to given query
"""
pass

View File

@ -22,4 +22,3 @@ async def ask_question(question: Question, conversation_id: str):
@router.get("/{conversation_id}}") @router.get("/{conversation_id}}")
async def get_question_history(conversation_id: str): async def get_question_history(conversation_id: str):
return {"message": f"Hello question {conversation_id}!"} return {"message": f"Hello question {conversation_id}!"}

21
src/schema.json Normal file
View File

@ -0,0 +1,21 @@
{
"classes": [
{
"class": "knowledge",
"description": "Knowledge for the language models",
"vectorizer": "none",
"properties": [
{
"name": "content",
"description": "The content of the document",
"dataType": ["text"]
},
{
"name": "document_id",
"description": "The id of the document, user facing",
"dataType": ["text"]
}
]
}
]
}