From 70a12c5820a982f435cbdd56c2eb63ae1b72f3c7 Mon Sep 17 00:00:00 2001 From: Tyler Perkins Date: Mon, 12 Jun 2023 23:42:56 -0400 Subject: [PATCH] Add query function --- docker-compose.yml | 23 +++++++++ src/config.yaml | 2 +- src/internal/documents.py | 84 ++++++++++++++++++++----------- src/internal/weaviate/weaviate.py | 33 ++++++++++-- src/routers/documents.py | 61 ++++++++++++++-------- 5 files changed, 149 insertions(+), 54 deletions(-) create mode 100644 docker-compose.yml diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..ed6c9fb --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,23 @@ +--- +version: '3.4' +services: + weaviate: + command: + - --host + - 0.0.0.0 + - --port + - '12345' + - --scheme + - http + image: semitechnologies/weaviate:1.19.7 + ports: + - 12345:12345 + restart: on-failure:0 + environment: + QUERY_DEFAULTS_LIMIT: 25 + AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED: 'true' + PERSISTENCE_DATA_PATH: '/var/lib/weaviate' + DEFAULT_VECTORIZER_MODULE: 'none' + ENABLE_MODULES: '' + CLUSTER_HOSTNAME: 'node1' +... diff --git a/src/config.yaml b/src/config.yaml index d61bbf4..9a0a41d 100644 --- a/src/config.yaml +++ b/src/config.yaml @@ -4,7 +4,7 @@ app: weaviate: url: "http://localhost:12345" text-field: "content" - index-name: "knowledge" + index-name: "Knowledge" schema-path: "schema.json" openai: url: "http://192.168.1.104:11111" diff --git a/src/internal/documents.py b/src/internal/documents.py index 0324f0c..125a177 100644 --- a/src/internal/documents.py +++ b/src/internal/documents.py @@ -22,7 +22,38 @@ The text splitter client db = None +def loadDocumentsIntoWeaviate(documentType: DocumentType, + documents: dict) -> None: + """ + Loads a document into Weaviate. + Expects documents to be of formate {name -> path} + """ + texts = [] + + for name, path in documents.items(): + documents = None + + if documentType == DocumentType.markdown: + loader = UnstructuredMarkdownLoader(path) + documents = loader.load() + else: + raise Exception("Document type not supported.") + + # Split up the document + newDocuments = splitDocuments(documents) + for doc in newDocuments: + doc.metadata["document_name"] = name + + texts.extend(newDocuments) + + db = getDatabase() + + logging.info(f"Parsed {len(documents)} documents") + + db.addDocuments(documents=texts) + + logging.info(f"Loaded document {len(documents)} into Weaviate") def loadDocumentIntoWeaviate(documentType: DocumentType, document_name : str, @@ -30,37 +61,18 @@ def loadDocumentIntoWeaviate(documentType: DocumentType, """ Loads a document into Weaviate. """ - global db + loadDocumentsIntoWeaviate(documentType, {document_name: path}) - documents = None - - if documentType == DocumentType.markdown: - loader = UnstructuredMarkdownLoader(path) - documents = loader.load() - else: - raise Exception("Document type not supported.") - - # Split up the document - texts = splitDocument(documents) - - for text in texts: - text.metadata["document_name"] = document_name - - if db is None: - print(config["app.weaviate.url"]) - client = weaviate.Client(config["app.weaviate.url"]) - db = WeaviateClient(client, - text_key=config["app.weaviate.text-field"], - index_name=config["app.weaviate.index-name"]) - - db.addDocuments(documents=texts) - - logging.info(f"Loaded document {len(documents)} into Weaviate") - - -def splitDocument(document: Document | Iterable[Document]) -> List[Document]: +def findSimilarDocuments(query: str, limit : int = 10) -> List[Document]: """ - Splits a document into multiple documents using spaCy + Finds similar documents to the query + """ + db = getDatabase() + return db.similaritySearch(query=query, k=limit) + +def splitDocuments(document: Iterable[Document]) -> List[Document]: + """ + Splits a document into multiple documents """ if document is None: raise Exception("Document is None") @@ -69,3 +81,17 @@ def splitDocument(document: Document | Iterable[Document]) -> List[Document]: return textSplitter.split_documents(document) +def getDatabase() -> WeaviateClient: + """ + Get a weaviate client instance + """ + global db + + if db is None: + logging.debug(config["app.weaviate.url"]) + client = weaviate.Client(config["app.weaviate.url"]) + db = WeaviateClient(client, + text_key = config["app.weaviate.text-field"], + index_name = config["app.weaviate.index-name"]) + + return db diff --git a/src/internal/weaviate/weaviate.py b/src/internal/weaviate/weaviate.py index 4bc1749..c95a49e 100644 --- a/src/internal/weaviate/weaviate.py +++ b/src/internal/weaviate/weaviate.py @@ -47,15 +47,13 @@ class WeaviateClient: vector=vector ) - - def addDocument(self, document: Document) -> None: """ Adds a given document to the store """ self.addDocuments([document]) - def similaritySearch(self, query: str, k: int = 10) -> List[str]: + def similaritySearch(self, query: str, k: int = 10) -> List[Document]: """ Searches for similar documents @@ -63,7 +61,34 @@ class WeaviateClient: query: Text to lookup k: Number of results to return, default 10 """ - pass + if self.embeddingProvider is None: + raise Exception("No embedding provider set") + vector = self.embeddingProvider.getEmbedding(query) + return self.similaritySearchByVector(vector, k) + + def similaritySearchByVector(self, vector: List[float], k: int = 10) -> List[Document]: + """ + Searches for similar documents + + Args: + vector: Vector to lookup + k: Number of results to return, default 10 + """ + vectorQuery = { "vector": vector } + query_obj = self.client.query.get(self.index_name, ["content", "document_name"]) + result = query_obj.with_near_vector(vectorQuery).with_limit(k).do() + + if "errors" in result: + raise Exception(result["errors"]) + results = [] + + print(result) + + for res in result["data"]["Get"][self.index_name]: + text = res.pop(self.text_key) + results.append(Document(page_content=text, metadata=res)) + + return results def removeDocument(self, document_id: str) -> None: """ diff --git a/src/routers/documents.py b/src/routers/documents.py index 2d386a6..ffca9c7 100644 --- a/src/routers/documents.py +++ b/src/routers/documents.py @@ -1,4 +1,4 @@ -from fastapi import APIRouter, UploadFile +from fastapi import APIRouter, UploadFile, responses from pydantic import BaseModel from enum import Enum import tempfile @@ -7,55 +7,63 @@ from internal.documentTypes import DocumentType from internal import documents +# set logger to debug +logging.basicConfig(level=logging.DEBUG) + router = APIRouter( prefix="/documents", tags=["documents"], responses={404: {"description": "Not found"}}, ) -@router.get("/{document_type}") -async def read_documents(document_type: DocumentType): +@router.get("/") +async def read_documents(): """ Get all documents """ pass -@router.get("/{document_id}") -async def read_document(document_id: str): - """ - Get a specific document - """ - pass @router.post("/") async def create_document(document_type: DocumentType, - file: UploadFile): + files: list[UploadFile]): """ Create a new document """ - tmp = tempfile.NamedTemporaryFile(delete=True) - document_id = file.filename - try: - logging.info(f"Uploaded file {file.filename} to {tmp.name}") + tmpFiles = [] + docs = {} + + for file in files: + logging.info(f"Uploaded file {file.filename}") + + tmp = tempfile.NamedTemporaryFile(delete=True) + document_id = file.filename + path = tmp.name # Write the file to a temporary file tmp.write(await file.read()) tmp.flush() - # Load the document - documents.loadDocumentIntoWeaviate(document_type, document_id, tmp.name) + docs[document_id] = path - finally: + tmpFiles.append(tmp) + + documents.loadDocumentsIntoWeaviate(document_type, docs) + + for tmp in tmpFiles: tmp.close() - return {"document_id": document_id} + return { + "documents": [ key for key in docs ] + } @router.put("/{document_id}") async def update_document(document_id: str): """ Update a document """ + pass @router.delete("/{document_id}") async def delete_document(document_id: str): @@ -65,10 +73,23 @@ async def delete_document(document_id: str): pass @router.get("/find") -async def find_document(query: str): +async def find_document(query: str, limit: int = 10): """ Finds a document with content similar to given query """ - pass + + docs = documents.findSimilarDocuments(query, limit) + + response = [] + + # format response + for doc in docs: + response.append({ + "document_id": doc.metadata["document_name"], + "content": doc.page_content + }) + return response + +