Add query function
This commit is contained in:
parent
084e1cd383
commit
70a12c5820
23
docker-compose.yml
Normal file
23
docker-compose.yml
Normal file
@ -0,0 +1,23 @@
|
||||
---
|
||||
version: '3.4'
|
||||
services:
|
||||
weaviate:
|
||||
command:
|
||||
- --host
|
||||
- 0.0.0.0
|
||||
- --port
|
||||
- '12345'
|
||||
- --scheme
|
||||
- http
|
||||
image: semitechnologies/weaviate:1.19.7
|
||||
ports:
|
||||
- 12345:12345
|
||||
restart: on-failure:0
|
||||
environment:
|
||||
QUERY_DEFAULTS_LIMIT: 25
|
||||
AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED: 'true'
|
||||
PERSISTENCE_DATA_PATH: '/var/lib/weaviate'
|
||||
DEFAULT_VECTORIZER_MODULE: 'none'
|
||||
ENABLE_MODULES: ''
|
||||
CLUSTER_HOSTNAME: 'node1'
|
||||
...
|
@ -4,7 +4,7 @@ app:
|
||||
weaviate:
|
||||
url: "http://localhost:12345"
|
||||
text-field: "content"
|
||||
index-name: "knowledge"
|
||||
index-name: "Knowledge"
|
||||
schema-path: "schema.json"
|
||||
openai:
|
||||
url: "http://192.168.1.104:11111"
|
||||
|
@ -22,7 +22,38 @@ The text splitter client
|
||||
|
||||
db = None
|
||||
|
||||
def loadDocumentsIntoWeaviate(documentType: DocumentType,
                              documents: dict) -> None:
    """
    Loads a batch of documents into Weaviate.

    Every file is loaded, split into chunks and each chunk is tagged with
    the name of the document it came from. The chunks of all files are
    then stored with a single addDocuments call.

    Args:
        documentType: Type of the files in the batch. Only
            DocumentType.markdown is handled.
        documents: Mapping of the format {name -> path}.

    Raises:
        Exception: If documentType is not supported.
    """
    texts = []

    for name, path in documents.items():
        # Use a separate local for the loaded file so the `documents`
        # mapping is not overwritten; the log lines below rely on it.
        if documentType == DocumentType.markdown:
            loaded = UnstructuredMarkdownLoader(path).load()
        else:
            raise Exception("Document type not supported.")

        # Split up the document
        chunks = splitDocuments(loaded)
        for chunk in chunks:
            chunk.metadata["document_name"] = name

        texts.extend(chunks)

    db = getDatabase()

    logging.info(f"Parsed {len(documents)} documents")

    db.addDocuments(documents=texts)

    logging.info(f"Loaded document {len(documents)} into Weaviate")
|
||||
|
||||
def loadDocumentIntoWeaviate(documentType: DocumentType,
|
||||
document_name : str,
|
||||
@ -30,37 +61,18 @@ def loadDocumentIntoWeaviate(documentType: DocumentType,
|
||||
"""
|
||||
Loads a document into Weaviate.
|
||||
"""
|
||||
global db
|
||||
loadDocumentsIntoWeaviate(documentType, {document_name: path})
|
||||
|
||||
documents = None
|
||||
|
||||
if documentType == DocumentType.markdown:
|
||||
loader = UnstructuredMarkdownLoader(path)
|
||||
documents = loader.load()
|
||||
else:
|
||||
raise Exception("Document type not supported.")
|
||||
|
||||
# Split up the document
|
||||
texts = splitDocument(documents)
|
||||
|
||||
for text in texts:
|
||||
text.metadata["document_name"] = document_name
|
||||
|
||||
if db is None:
|
||||
print(config["app.weaviate.url"])
|
||||
client = weaviate.Client(config["app.weaviate.url"])
|
||||
db = WeaviateClient(client,
|
||||
text_key=config["app.weaviate.text-field"],
|
||||
index_name=config["app.weaviate.index-name"])
|
||||
|
||||
db.addDocuments(documents=texts)
|
||||
|
||||
logging.info(f"Loaded document {len(documents)} into Weaviate")
|
||||
|
||||
|
||||
def splitDocument(document: Document | Iterable[Document]) -> List[Document]:
|
||||
def findSimilarDocuments(query: str, limit: int = 10) -> List[Document]:
    """
    Finds documents similar to the query.

    Args:
        query: Text to search for.
        limit: Passed to the database client as the number of results
            to return, default 10.
    """
    return getDatabase().similaritySearch(query=query, k=limit)
|
||||
|
||||
def splitDocuments(document: Iterable[Document]) -> List[Document]:
|
||||
"""
|
||||
Splits a document into multiple documents
|
||||
"""
|
||||
if document is None:
|
||||
raise Exception("Document is None")
|
||||
@ -69,3 +81,17 @@ def splitDocument(document: Document | Iterable[Document]) -> List[Document]:
|
||||
|
||||
return textSplitter.split_documents(document)
|
||||
|
||||
def getDatabase() -> WeaviateClient:
    """
    Returns the shared WeaviateClient, creating it on first use.

    The instance is kept in the module-level ``db`` variable and is built
    from the app.weaviate.* configuration entries.
    """
    global db

    # Already created by an earlier call: reuse it.
    if db is not None:
        return db

    weaviateUrl = config["app.weaviate.url"]
    logging.debug(weaviateUrl)

    client = weaviate.Client(weaviateUrl)
    db = WeaviateClient(
        client,
        text_key=config["app.weaviate.text-field"],
        index_name=config["app.weaviate.index-name"],
    )
    return db
|
||||
|
@ -47,15 +47,13 @@ class WeaviateClient:
|
||||
vector=vector
|
||||
)
|
||||
|
||||
|
||||
|
||||
def addDocument(self, document: Document) -> None:
    """
    Stores a single document.

    Thin wrapper that hands the document to addDocuments as a
    one-element list.
    """
    self.addDocuments(documents=[document])
|
||||
|
||||
def similaritySearch(self, query: str, k: int = 10) -> List[str]:
|
||||
def similaritySearch(self, query: str, k: int = 10) -> List[Document]:
|
||||
"""
|
||||
Searches for similar documents
|
||||
|
||||
@ -63,7 +61,34 @@ class WeaviateClient:
|
||||
query: Text to lookup
|
||||
k: Number of results to return, default 10
|
||||
"""
|
||||
pass
|
||||
if self.embeddingProvider is None:
|
||||
raise Exception("No embedding provider set")
|
||||
vector = self.embeddingProvider.getEmbedding(query)
|
||||
return self.similaritySearchByVector(vector, k)
|
||||
|
||||
def similaritySearchByVector(self, vector: List[float], k: int = 10) -> List[Document]:
    """
    Searches for documents close to the given vector.

    Args:
        vector: Vector to lookup
        k: Number of results to return, default 10

    Returns:
        One Document per hit. page_content is the value of the text
        field, metadata holds the remaining returned properties.

    Raises:
        Exception: If the response contains an "errors" entry.
    """
    # Ask for the configured text field instead of a hard-coded name, so
    # the pop() below always finds the key it removes.
    properties = [self.text_key, "document_name"]

    result = (
        self.client.query
        .get(self.index_name, properties)
        .with_near_vector({"vector": vector})
        .with_limit(k)
        .do()
    )

    if "errors" in result:
        raise Exception(result["errors"])

    # Treat a null hit list the same as an empty one.
    hits = result["data"]["Get"][self.index_name] or []

    results = []
    for hit in hits:
        text = hit.pop(self.text_key)
        results.append(Document(page_content=text, metadata=hit))

    return results
|
||||
|
||||
def removeDocument(self, document_id: str) -> None:
|
||||
"""
|
||||
|
@ -1,4 +1,4 @@
|
||||
from fastapi import APIRouter, UploadFile
|
||||
from fastapi import APIRouter, UploadFile, responses
|
||||
from pydantic import BaseModel
|
||||
from enum import Enum
|
||||
import tempfile
|
||||
@ -7,55 +7,63 @@ from internal.documentTypes import DocumentType
|
||||
|
||||
from internal import documents
|
||||
|
||||
# set logger to debug
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/documents",
|
||||
tags=["documents"],
|
||||
responses={404: {"description": "Not found"}},
|
||||
)
|
||||
|
||||
@router.get("/{document_type}")
|
||||
async def read_documents(document_type: DocumentType):
|
||||
@router.get("/")
async def read_documents():
    """
    Get all documents

    Placeholder: the body is empty, so the handler currently returns None.
    """
    pass
|
||||
|
||||
# NOTE(review): this path-parameter route is declared before the "/find"
# route further down in this module. If routes are matched in declaration
# order, GET /documents/find is handled here with document_id="find".
# Confirm and move "/find" above this route if so.
@router.get("/{document_id}")
async def read_document(document_id: str):
    """
    Get a specific document

    Placeholder: document_id is not used yet and the handler returns None.
    """
    pass
|
||||
|
||||
@router.post("/")
async def create_document(document_type: DocumentType,
                          files: list[UploadFile]):
    """
    Create new documents from the uploaded files

    Each upload is copied into a temporary file; all files are then
    loaded into Weaviate as one batch. The temporary files are closed
    whether or not loading succeeds.

    Args:
        document_type: Type shared by all uploaded files
        files: Uploaded files, the filename is used as document id

    Returns:
        {"documents": [...]} with the ids of the submitted documents
    """
    tmpFiles = []
    docs = {}

    try:
        for file in files:
            logging.info(f"Uploaded file {file.filename}")

            tmp = tempfile.NamedTemporaryFile(delete=True)
            # Track the file immediately so it is also closed when the
            # copy below fails.
            tmpFiles.append(tmp)

            # Write the file to a temporary file
            tmp.write(await file.read())
            tmp.flush()

            # NOTE: uploads sharing a filename overwrite each other here.
            docs[file.filename] = tmp.name

        documents.loadDocumentsIntoWeaviate(document_type, docs)
    finally:
        # The files were created with delete=True, closing removes them.
        for tmp in tmpFiles:
            tmp.close()

    return {
        "documents": list(docs)
    }
|
||||
|
||||
@router.put("/{document_id}")
async def update_document(document_id: str):
    """
    Update a document

    Placeholder: document_id is not used yet and the handler returns None.
    """
    pass
|
||||
|
||||
@router.delete("/{document_id}")
|
||||
async def delete_document(document_id: str):
|
||||
@ -65,10 +73,23 @@ async def delete_document(document_id: str):
|
||||
pass
|
||||
|
||||
@router.get("/find")
async def find_document(query: str, limit: int = 10):
    """
    Look up documents whose content is similar to the given query

    Args:
        query: Text to search for
        limit: Forwarded to documents.findSimilarDocuments, default 10

    Returns:
        A list with one {"document_id", "content"} entry per match
    """
    matches = documents.findSimilarDocuments(query, limit)

    # format response
    return [
        {
            "document_id": match.metadata["document_name"],
            "content": match.page_content
        }
        for match in matches
    ]
|
||||
|
||||
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user