From 084e1cd383cb8802d801cbec5b786c25d459e316 Mon Sep 17 00:00:00 2001 From: Tyler Perkins Date: Mon, 12 Jun 2023 22:30:39 -0400 Subject: [PATCH] Add add document endpoint --- src/config.py | 14 ++ src/config.yaml | 20 ++- src/internal/documentTypes.py | 12 ++ src/internal/documents.py | 70 +++++++++- src/internal/embeddings/EmbeddingProvider.py | 11 ++ .../embeddings/OpenAIEmbeddingProvider.py | 48 +++++++ src/internal/embeddings/__init__.py | 0 src/internal/weaviate/weaviate.py | 121 ++++++++++++++++++ src/main.py | 9 +- src/routers/documents.py | 58 +++++---- src/routers/question.py | 1 - src/schema.json | 21 +++ 12 files changed, 353 insertions(+), 32 deletions(-) create mode 100644 src/internal/documentTypes.py create mode 100644 src/internal/embeddings/EmbeddingProvider.py create mode 100644 src/internal/embeddings/OpenAIEmbeddingProvider.py create mode 100644 src/internal/embeddings/__init__.py create mode 100644 src/internal/weaviate/weaviate.py create mode 100644 src/schema.json diff --git a/src/config.py b/src/config.py index 5df9ac2..2f8f39b 100644 --- a/src/config.py +++ b/src/config.py @@ -2,6 +2,8 @@ import envyaml import os import logging from logging import Formatter, StreamHandler +import weaviate +import json config = envyaml.EnvYAML(os.environ.get('CONFIG_PATH', 'config.yaml')) @@ -35,3 +37,15 @@ def initLogger() -> None: logging.basicConfig(level=logging.INFO) logging.warning("Invalid log level. Using INFO as default") +def initEnvironment() -> None: + os.environ["OPENAI_PROXY"] = config["app.openai.url"] + os.environ["OPENAI_API_KEY"] = config["app.openai.api_key"] + +def initWeviate() -> None: + logging.debug("Initializing Weaviate") + client = weaviate.Client(config["app.weaviate.url"]) + with open(config["app.weaviate.schema-path"]) as file: + schema = json.load(file) + if not client.schema.contains(schema): + logging.debug("Creating Weaviate schema") + client.schema.create(schema) diff --git a/src/config.yaml b/src/config.yaml index 6dd43b1..d61bbf4 100644 --- a/src/config.yaml +++ b/src/config.yaml @@ -1,6 +1,18 @@ -weaviate: - url: "localhost:" - app: log: - level: "" + level: "debug" + weaviate: + url: "http://localhost:12345" + text-field: "content" + index-name: "knowledge" + schema-path: "schema.json" + openai: + url: "http://192.168.1.104:11111" + api_key: "sk-" + document: + # How many characters should each chunked document be split into? + split_chunk_size: 500 + # How much overlap should each chunk have with its neighbor + split_chunk_overlap: 20 + # What model in our OpenAPI api should we use? + embeddings_model: "text-embedding" diff --git a/src/internal/documentTypes.py b/src/internal/documentTypes.py new file mode 100644 index 0000000..a4c83aa --- /dev/null +++ b/src/internal/documentTypes.py @@ -0,0 +1,12 @@ +from enum import Enum + +class DocumentType(str, Enum): + """ + Enumerated type for document types that we support + """ + markdown = "md" + #html = "html" + #pdf = "pdf" + #epub = "epub" + #odt = "odt" + #docx = "docx" diff --git a/src/internal/documents.py b/src/internal/documents.py index a78fb7e..0324f0c 100644 --- a/src/internal/documents.py +++ b/src/internal/documents.py @@ -1,3 +1,71 @@ +from langchain.document_loaders import UnstructuredMarkdownLoader +from langchain.schema import Document +from langchain.embeddings import OpenAIEmbeddings +from langchain.vectorstores import Weaviate +from .documentTypes import DocumentType +from config import config +import logging +import weaviate +from typing import Iterable, List +from langchain.text_splitter import NLTKTextSplitter +from internal.weaviate.weaviate import WeaviateClient -from langchain.text_splitter import SpacyTextSplitter +# Globals +# ======= +# The text splitter + +textSplitter = NLTKTextSplitter(chunk_size=config["app.document.split_chunk_size"], + chunk_overlap=config["app.document.split_chunk_overlap"]) +""" +The text splitter client +""" + +db = None + + + +def loadDocumentIntoWeaviate(documentType: DocumentType, + document_name : str, + path: str) -> None: + """ + Loads a document into Weaviate. + """ + global db + + documents = None + + if documentType == DocumentType.markdown: + loader = UnstructuredMarkdownLoader(path) + documents = loader.load() + else: + raise Exception("Document type not supported.") + + # Split up the document + texts = splitDocument(documents) + + for text in texts: + text.metadata["document_name"] = document_name + + if db is None: + print(config["app.weaviate.url"]) + client = weaviate.Client(config["app.weaviate.url"]) + db = WeaviateClient(client, + text_key=config["app.weaviate.text-field"], + index_name=config["app.weaviate.index-name"]) + + db.addDocuments(documents=texts) + + logging.info(f"Loaded document {len(documents)} into Weaviate") + + +def splitDocument(document: Document | Iterable[Document]) -> List[Document]: + """ + Splits a document into multiple documents using spaCy + """ + if document is None: + raise Exception("Document is None") + + global textSplitter + + return textSplitter.split_documents(document) diff --git a/src/internal/embeddings/EmbeddingProvider.py b/src/internal/embeddings/EmbeddingProvider.py new file mode 100644 index 0000000..4c456f9 --- /dev/null +++ b/src/internal/embeddings/EmbeddingProvider.py @@ -0,0 +1,11 @@ +from abc import ABC, abstractmethod +from typing import List + +class EmbeddingProvider(ABC): + @abstractmethod + def getEmbedding(self, word: str) -> List[float]: + """ + Returns the embedding for the given word + """ + pass + diff --git a/src/internal/embeddings/OpenAIEmbeddingProvider.py b/src/internal/embeddings/OpenAIEmbeddingProvider.py new file mode 100644 index 0000000..60cedce --- /dev/null +++ b/src/internal/embeddings/OpenAIEmbeddingProvider.py @@ -0,0 +1,48 @@ +from .EmbeddingProvider import EmbeddingProvider +from typing import List +import openai +from config import config + +class OpenAIEmbeddingProvider(EmbeddingProvider): + + def __init__(self): + super().__init__() + self.openai_api_key = config["app.openai.api_key"] + self.openai_url = config["app.openai.url"] + self.model = config["app.document.embeddings_model"] + + def getEmbedding(self, text: str) -> List[float]: + """ + Returns the embedding for the given string + """ + openai.api_key = self.openai_api_key + openai.api_base = self.openai_url + return openai.Embedding.create(input = [text], + model=self.model)['data'][0]['embedding'] + + @property + def OPENAI_API_KEY(self): + return self.openai_api_key + + @OPENAI_API_KEY.setter + def OPENAI_API_KEY(self, value): + self.openai_api_key = value + + @property + def OPENAI_URL(self): + return self.openai_url + + @OPENAI_URL.setter + def OPENAI_URL(self, value): + self.openai_url = value + + @property + def MODEL(self): + return self.model + + @MODEL.setter + def MODEL(self, value): + self.model = value + + + diff --git a/src/internal/embeddings/__init__.py b/src/internal/embeddings/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/internal/weaviate/weaviate.py b/src/internal/weaviate/weaviate.py new file mode 100644 index 0000000..4bc1749 --- /dev/null +++ b/src/internal/weaviate/weaviate.py @@ -0,0 +1,121 @@ +import weaviate +import json +import logging +from typing import Iterable, List +from internal.embeddings.OpenAIEmbeddingProvider import OpenAIEmbeddingProvider +from internal.embeddings.EmbeddingProvider import EmbeddingProvider +from langchain.schema import Document + +class WeaviateClient: + # Constructor + def __init__(self, client: weaviate.Client, + text_key: str = "content", + index_name: str = "documents", + embeddingProvider: EmbeddingProvider = OpenAIEmbeddingProvider(), + ): + self.client = client + self.embeddingProvider = embeddingProvider + self.text_key = text_key + self.index_name = index_name + + @classmethod + def fromUrl(cls, endpoint: str): + """ + Creates a WeaviateClient from an endpoint + """ + client = weaviate.Client(endpoint) + return cls(client) + + def addDocuments(self, documents: Iterable[Document]) -> None: + """ + Adds a list of documents to the store + """ + if self.embeddingProvider is None: + raise Exception("No embedding provider set") + + with self.client.batch as batch: + for i, text in enumerate(documents): + data_properties = {self.text_key: text.page_content} + if text.metadata is not None: + for key, val in text.metadata.items(): + data_properties[key] = val + + vector = self.embeddingProvider.getEmbedding(text.page_content) + batch.add_data_object( + data_object=data_properties, + class_name=self.index_name, + vector=vector + ) + + + + def addDocument(self, document: Document) -> None: + """ + Adds a given document to the store + """ + self.addDocuments([document]) + + def similaritySearch(self, query: str, k: int = 10) -> List[str]: + """ + Searches for similar documents + + Args: + query: Text to lookup + k: Number of results to return, default 10 + """ + pass + + def removeDocument(self, document_id: str) -> None: + """ + Removes a document from the store + + Args: + document: Document to remove + """ + pass + + def getDocument(self, document_id: str) -> str: + """ + Returns a document from the store + + Args: + document: Document to return + """ + pass + + + @property + def ENDPOINT(self): + return self.endpoint + + @ENDPOINT.setter + def ENDPOINT(self, endpoint: str): + self.endpoint = endpoint + + @property + def EMBEDDING_PROVIDER(self): + return self.embeddingProvider + + @EMBEDDING_PROVIDER.setter + def EMBEDDING_PROVIDER(self, embeddingProvider: EmbeddingProvider): + self.embeddingProvider = embeddingProvider + + @property + def TEXT_KEY(self): + return self.text_key + + @TEXT_KEY.setter + def TEXT_KEY(self, text_key: str): + self.text_key = text_key + + @property + def INDEX_NAME(self): + return self.index_name + + @INDEX_NAME.setter + def INDEX_NAME(self, index_name: str): + self.index_name = index_name + + + + diff --git a/src/main.py b/src/main.py index c6a1458..ee20985 100644 --- a/src/main.py +++ b/src/main.py @@ -4,18 +4,19 @@ from fastapi.responses import RedirectResponse from routers import question, documents import logging -from config import config, initLogger +from config import config, initLogger, initEnvironment, initWeviate +initEnvironment() initLogger() +# Init weaviate, if not done already +initWeviate() + app = FastAPI() app.include_router(question.router) app.include_router(documents.router) -logging.warn("Test message") - - @app.get("/") async def root(): """ diff --git a/src/routers/documents.py b/src/routers/documents.py index 57ef928..2d386a6 100644 --- a/src/routers/documents.py +++ b/src/routers/documents.py @@ -1,6 +1,11 @@ -from fastapi import APIRouter +from fastapi import APIRouter, UploadFile from pydantic import BaseModel from enum import Enum +import tempfile +import logging +from internal.documentTypes import DocumentType + +from internal import documents router = APIRouter( prefix="/documents", @@ -8,21 +13,6 @@ router = APIRouter( responses={404: {"description": "Not found"}}, ) -class Document(BaseModel): - id: int - title: str - content: str - -# Document Type enum -class DocumentType(str, Enum): - markdown = "md" - #html = "html" - #pdf = "pdf" - epub = "epub" - odt = "odt" - docx = "docx" - - @router.get("/{document_type}") async def read_documents(document_type: DocumentType): """ @@ -31,30 +21,54 @@ async def read_documents(document_type: DocumentType): pass @router.get("/{document_id}") -async def read_document(document_id: int): +async def read_document(document_id: str): """ Get a specific document """ pass @router.post("/") -async def create_document(document: Document): +async def create_document(document_type: DocumentType, + file: UploadFile): """ Create a new document """ - pass + tmp = tempfile.NamedTemporaryFile(delete=True) + document_id = file.filename + + try: + logging.info(f"Uploaded file {file.filename} to {tmp.name}") + + # Write the file to a temporary file + tmp.write(await file.read()) + tmp.flush() + + # Load the document + documents.loadDocumentIntoWeaviate(document_type, document_id, tmp.name) + + finally: + tmp.close() + + return {"document_id": document_id} @router.put("/{document_id}") -async def update_document(document_id: int, document: Document): +async def update_document(document_id: str): """ Update a document """ - pass @router.delete("/{document_id}") -async def delete_document(document_id: int): +async def delete_document(document_id: str): """ Delete a document """ pass +@router.get("/find") +async def find_document(query: str): + """ + Finds a document with content similar to given query + """ + pass + + diff --git a/src/routers/question.py b/src/routers/question.py index e506754..577f9bb 100644 --- a/src/routers/question.py +++ b/src/routers/question.py @@ -22,4 +22,3 @@ async def ask_question(question: Question, conversation_id: str): @router.get("/{conversation_id}}") async def get_question_history(conversation_id: str): return {"message": f"Hello question {conversation_id}!"} - diff --git a/src/schema.json b/src/schema.json new file mode 100644 index 0000000..df1a5e1 --- /dev/null +++ b/src/schema.json @@ -0,0 +1,21 @@ +{ + "classes": [ + { + "class": "knowledge", + "description": "Knowledge for the language models", + "vectorizer": "none", + "properties": [ + { + "name": "content", + "description": "The content of the document", + "dataType": ["text"] + }, + { + "name": "document_id", + "description": "The id of the document, user facing", + "dataType": ["text"] + } + ] + } + ] +}