Add add document endpoint
This commit is contained in:
parent
f9f1982442
commit
084e1cd383
@ -2,6 +2,8 @@ import envyaml
|
|||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
from logging import Formatter, StreamHandler
|
from logging import Formatter, StreamHandler
|
||||||
|
import weaviate
|
||||||
|
import json
|
||||||
|
|
||||||
config = envyaml.EnvYAML(os.environ.get('CONFIG_PATH', 'config.yaml'))
|
config = envyaml.EnvYAML(os.environ.get('CONFIG_PATH', 'config.yaml'))
|
||||||
|
|
||||||
@ -35,3 +37,15 @@ def initLogger() -> None:
|
|||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
logging.warning("Invalid log level. Using INFO as default")
|
logging.warning("Invalid log level. Using INFO as default")
|
||||||
|
|
||||||
|
def initEnvironment() -> None:
    """
    Exports the OpenAI settings from the configuration as the
    OPENAI_PROXY and OPENAI_API_KEY environment variables.
    """
    # (environment variable, configuration key) pairs, exported in order.
    exported = (
        ("OPENAI_PROXY", "app.openai.url"),
        ("OPENAI_API_KEY", "app.openai.api_key"),
    )
    for variable, config_key in exported:
        os.environ[variable] = config[config_key]
|
||||||
|
|
||||||
|
def initWeviate() -> None:
    """
    Creates the Weaviate schema described by the configured schema file,
    unless the Weaviate instance already contains it.
    """
    logging.debug("Initializing Weaviate")
    client = weaviate.Client(config["app.weaviate.url"])

    with open(config["app.weaviate.schema-path"]) as schema_file:
        schema = json.loads(schema_file.read())

    # Nothing to do when the schema is already present.
    if client.schema.contains(schema):
        return

    logging.debug("Creating Weaviate schema")
    client.schema.create(schema)
|
||||||
|
@ -1,6 +1,18 @@
|
|||||||
weaviate:
|
|
||||||
url: "localhost:"
|
|
||||||
|
|
||||||
app:
|
app:
|
||||||
log:
|
log:
|
||||||
level: ""
|
level: "debug"
|
||||||
|
weaviate:
|
||||||
|
url: "http://localhost:12345"
|
||||||
|
text-field: "content"
|
||||||
|
index-name: "knowledge"
|
||||||
|
schema-path: "schema.json"
|
||||||
|
openai:
|
||||||
|
url: "http://192.168.1.104:11111"
|
||||||
|
api_key: "sk-"
|
||||||
|
document:
|
||||||
|
# How many characters should each chunked document be split into?
|
||||||
|
split_chunk_size: 500
|
||||||
|
# How much overlap should each chunk have with its neighbor?
|
||||||
|
split_chunk_overlap: 20
|
||||||
|
# Which model of our OpenAI API should we use for embeddings?
|
||||||
|
embeddings_model: "text-embedding"
|
||||||
|
12
src/internal/documentTypes.py
Normal file
12
src/internal/documentTypes.py
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
class DocumentType(str, Enum):
    """
    Document types that we support.

    Members are also plain strings, so they compare equal to their value
    (for example "md").
    """
    markdown = "md"
    # Not enabled yet: html, pdf, epub, odt, docx
|
@ -1,3 +1,71 @@
|
|||||||
|
from langchain.document_loaders import UnstructuredMarkdownLoader
|
||||||
|
from langchain.schema import Document
|
||||||
|
from langchain.embeddings import OpenAIEmbeddings
|
||||||
|
from langchain.vectorstores import Weaviate
|
||||||
|
from .documentTypes import DocumentType
|
||||||
|
from config import config
|
||||||
|
import logging
|
||||||
|
import weaviate
|
||||||
|
from typing import Iterable, List
|
||||||
|
from langchain.text_splitter import NLTKTextSplitter
|
||||||
|
from internal.weaviate.weaviate import WeaviateClient
|
||||||
|
|
||||||
from langchain.text_splitter import SpacyTextSplitter
|
# Globals
# =======

# The text splitter client. Chunk size and chunk overlap are read from the
# "app.document" section of the configuration.
textSplitter = NLTKTextSplitter(
    chunk_size=config["app.document.split_chunk_size"],
    chunk_overlap=config["app.document.split_chunk_overlap"],
)

# Weaviate wrapper; stays None until loadDocumentIntoWeaviate creates it on
# first use.
db = None
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def loadDocumentIntoWeaviate(documentType: DocumentType,
                             document_name : str,
                             path: str) -> None:
    """
    Loads a document into Weaviate.

    The file at ``path`` is read with the loader matching ``documentType``,
    split into chunks, tagged with ``document_name`` and stored through the
    module-level Weaviate client, which is created on first use.

    Args:
        documentType: Type of the file; only markdown is supported.
        document_name: Value stored in each chunk's "document_name" metadata.
        path: Path of the file to load.

    Raises:
        ValueError: If ``documentType`` is not supported.
    """
    global db

    if documentType != DocumentType.markdown:
        raise ValueError("Document type not supported.")

    documents = UnstructuredMarkdownLoader(path).load()

    # Split up the document
    texts = splitDocument(documents)

    # NOTE(review): schema.json declares a "document_id" property while the
    # chunks are tagged with "document_name" - confirm which one is intended.
    for text in texts:
        text.metadata["document_name"] = document_name

    if db is None:
        weaviate_url = config["app.weaviate.url"]
        # Logged at debug level instead of printed to stdout.
        logging.debug(f"Connecting to Weaviate at {weaviate_url}")
        db = WeaviateClient(weaviate.Client(weaviate_url),
                            text_key=config["app.weaviate.text-field"],
                            index_name=config["app.weaviate.index-name"])

    db.addDocuments(documents=texts)

    logging.info(f"Loaded document {document_name} into Weaviate "
                 f"as {len(texts)} chunk(s)")
|
||||||
|
|
||||||
|
|
||||||
|
def splitDocument(document: Document | Iterable[Document]) -> List[Document]:
    """
    Splits one or more documents into smaller documents using the
    module-level text splitter.

    Args:
        document: A single document or an iterable of documents.

    Returns:
        The chunks produced by the text splitter.

    Raises:
        ValueError: If ``document`` is None.
    """
    if document is None:
        raise ValueError("Document is None")

    # split_documents works on an iterable of documents, so a lone Document
    # is wrapped instead of being iterated over itself.
    if isinstance(document, Document):
        document = [document]

    return textSplitter.split_documents(document)
|
||||||
|
|
||||||
|
11
src/internal/embeddings/EmbeddingProvider.py
Normal file
11
src/internal/embeddings/EmbeddingProvider.py
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
class EmbeddingProvider(ABC):
    """
    Interface for objects that can compute embeddings.
    """

    @abstractmethod
    def getEmbedding(self, word: str) -> List[float]:
        """
        Computes the embedding for ``word``.

        Subclasses must override this method; the base implementation does
        nothing and returns None.
        """
|
||||||
|
|
48
src/internal/embeddings/OpenAIEmbeddingProvider.py
Normal file
48
src/internal/embeddings/OpenAIEmbeddingProvider.py
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
from .EmbeddingProvider import EmbeddingProvider
|
||||||
|
from typing import List
|
||||||
|
import openai
|
||||||
|
from config import config
|
||||||
|
|
||||||
|
class OpenAIEmbeddingProvider(EmbeddingProvider):
    """
    Embedding provider that requests embeddings through the ``openai``
    module.

    The API key, base URL and model name are read from the configuration
    when the provider is created; they can be changed afterwards through
    the OPENAI_API_KEY, OPENAI_URL and MODEL properties.
    """

    def __init__(self):
        super().__init__()
        self.openai_api_key = config["app.openai.api_key"]
        self.openai_url = config["app.openai.url"]
        self.model = config["app.document.embeddings_model"]

    def getEmbedding(self, text: str) -> List[float]:
        """
        Requests the embedding of ``text`` from the configured model.
        """
        # The openai module keeps its settings at module level, so apply
        # this provider's key and base URL before every request.
        openai.api_key = self.openai_api_key
        openai.api_base = self.openai_url

        response = openai.Embedding.create(input=[text], model=self.model)
        first_result = response['data'][0]
        return first_result['embedding']

    @property
    def OPENAI_API_KEY(self):
        """API key applied to the openai module before each request."""
        return self.openai_api_key

    @OPENAI_API_KEY.setter
    def OPENAI_API_KEY(self, value):
        self.openai_api_key = value

    @property
    def OPENAI_URL(self):
        """Base URL applied to the openai module before each request."""
        return self.openai_url

    @OPENAI_URL.setter
    def OPENAI_URL(self, value):
        self.openai_url = value

    @property
    def MODEL(self):
        """Name of the model passed to the embedding request."""
        return self.model

    @MODEL.setter
    def MODEL(self, value):
        self.model = value
|
||||||
|
|
||||||
|
|
||||||
|
|
0
src/internal/embeddings/__init__.py
Normal file
0
src/internal/embeddings/__init__.py
Normal file
121
src/internal/weaviate/weaviate.py
Normal file
121
src/internal/weaviate/weaviate.py
Normal file
@ -0,0 +1,121 @@
|
|||||||
|
import weaviate
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Iterable, List
|
||||||
|
from internal.embeddings.OpenAIEmbeddingProvider import OpenAIEmbeddingProvider
|
||||||
|
from internal.embeddings.EmbeddingProvider import EmbeddingProvider
|
||||||
|
from langchain.schema import Document
|
||||||
|
|
||||||
|
class WeaviateClient:
    """
    Wrapper around a ``weaviate.Client`` that stores documents together with
    vectors computed by an ``EmbeddingProvider``.
    """

    # Constructor
    def __init__(self, client: weaviate.Client,
                 text_key: str = "content",
                 index_name: str = "documents",
                 embeddingProvider: "EmbeddingProvider | None" = None,
                 ):
        """
        Args:
            client: Weaviate client used for all operations.
            text_key: Property under which a document's text is stored.
            index_name: Weaviate class name the objects are added to.
            embeddingProvider: Provider used to embed document text. When
                omitted, a new OpenAIEmbeddingProvider is created.
        """
        # Build the default provider here rather than in the signature: a
        # default argument is evaluated once, when the class is defined, and
        # would then be shared by every client.
        if embeddingProvider is None:
            embeddingProvider = OpenAIEmbeddingProvider()

        self.client = client
        self.embeddingProvider = embeddingProvider
        self.text_key = text_key
        self.index_name = index_name
        # Only known when the client was built by fromUrl or set through the
        # ENDPOINT property; initialised so ENDPOINT never raises.
        self.endpoint = None

    @classmethod
    def fromUrl(cls, endpoint: str):
        """
        Creates a WeaviateClient from an endpoint
        """
        instance = cls(weaviate.Client(endpoint))
        instance.endpoint = endpoint
        return instance

    def addDocuments(self, documents: Iterable[Document]) -> None:
        """
        Adds a list of documents to the store

        Each document's text is stored under ``text_key`` next to its
        metadata entries, with the vector returned by the embedding provider.

        Raises:
            Exception: If no embedding provider is set.
        """
        if self.embeddingProvider is None:
            raise Exception("No embedding provider set")

        with self.client.batch as batch:
            for document in documents:
                data_properties = {self.text_key: document.page_content}
                if document.metadata is not None:
                    data_properties.update(document.metadata)

                vector = self.embeddingProvider.getEmbedding(document.page_content)
                batch.add_data_object(
                    data_object=data_properties,
                    class_name=self.index_name,
                    vector=vector
                )

    def addDocument(self, document: Document) -> None:
        """
        Adds a given document to the store
        """
        self.addDocuments([document])

    def similaritySearch(self, query: str, k: int = 10) -> List[str]:
        """
        Searches for similar documents

        Not implemented yet: currently does nothing and returns None.

        Args:
            query: Text to lookup
            k: Number of results to return, default 10
        """
        pass

    def removeDocument(self, document_id: str) -> None:
        """
        Removes a document from the store

        Not implemented yet: currently does nothing.

        Args:
            document_id: Id of the document to remove
        """
        pass

    def getDocument(self, document_id: str) -> str:
        """
        Returns a document from the store

        Not implemented yet: currently does nothing and returns None.

        Args:
            document_id: Id of the document to return
        """
        pass

    @property
    def ENDPOINT(self):
        """Endpoint recorded by fromUrl or the setter, otherwise None."""
        return self.endpoint

    @ENDPOINT.setter
    def ENDPOINT(self, endpoint: str):
        # Only records the value; the wrapped client is not reconnected.
        self.endpoint = endpoint

    @property
    def EMBEDDING_PROVIDER(self):
        """Provider used by addDocuments to embed document text."""
        return self.embeddingProvider

    @EMBEDDING_PROVIDER.setter
    def EMBEDDING_PROVIDER(self, embeddingProvider: EmbeddingProvider):
        self.embeddingProvider = embeddingProvider

    @property
    def TEXT_KEY(self):
        """Property name under which document text is stored."""
        return self.text_key

    @TEXT_KEY.setter
    def TEXT_KEY(self, text_key: str):
        self.text_key = text_key

    @property
    def INDEX_NAME(self):
        """Weaviate class name that new objects are added to."""
        return self.index_name

    @INDEX_NAME.setter
    def INDEX_NAME(self, index_name: str):
        self.index_name = index_name
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -4,18 +4,19 @@ from fastapi.responses import RedirectResponse
|
|||||||
from routers import question, documents
|
from routers import question, documents
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from config import config, initLogger
|
from config import config, initLogger, initEnvironment, initWeviate
|
||||||
|
|
||||||
|
initEnvironment()
|
||||||
initLogger()
|
initLogger()
|
||||||
|
|
||||||
|
# Init weaviate, if not done already
|
||||||
|
initWeviate()
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
app.include_router(question.router)
|
app.include_router(question.router)
|
||||||
app.include_router(documents.router)
|
app.include_router(documents.router)
|
||||||
|
|
||||||
logging.warn("Test message")
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/")
|
@app.get("/")
|
||||||
async def root():
|
async def root():
|
||||||
"""
|
"""
|
||||||
|
@ -1,6 +1,11 @@
|
|||||||
from fastapi import APIRouter
|
from fastapi import APIRouter, UploadFile
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
import tempfile
|
||||||
|
import logging
|
||||||
|
from internal.documentTypes import DocumentType
|
||||||
|
|
||||||
|
from internal import documents
|
||||||
|
|
||||||
router = APIRouter(
|
router = APIRouter(
|
||||||
prefix="/documents",
|
prefix="/documents",
|
||||||
@ -8,21 +13,6 @@ router = APIRouter(
|
|||||||
responses={404: {"description": "Not found"}},
|
responses={404: {"description": "Not found"}},
|
||||||
)
|
)
|
||||||
|
|
||||||
class Document(BaseModel):
|
|
||||||
id: int
|
|
||||||
title: str
|
|
||||||
content: str
|
|
||||||
|
|
||||||
# Document Type enum
|
|
||||||
class DocumentType(str, Enum):
|
|
||||||
markdown = "md"
|
|
||||||
#html = "html"
|
|
||||||
#pdf = "pdf"
|
|
||||||
epub = "epub"
|
|
||||||
odt = "odt"
|
|
||||||
docx = "docx"
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{document_type}")
|
@router.get("/{document_type}")
|
||||||
async def read_documents(document_type: DocumentType):
|
async def read_documents(document_type: DocumentType):
|
||||||
"""
|
"""
|
||||||
@ -31,30 +21,54 @@ async def read_documents(document_type: DocumentType):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@router.get("/{document_id}")
async def read_document(document_id: str):
    """
    Get a specific document

    Placeholder: no lookup is performed yet and nothing is returned.
    """
    # NOTE(review): GET "/{document_type}" is registered earlier with the
    # same path shape, so this handler is probably never reached - confirm.
    return None
|
||||||
|
|
||||||
@router.post("/")
async def create_document(document_type: DocumentType,
                          file: UploadFile):
    """
    Create a new document

    The upload is written to a temporary file, which is then loaded into
    Weaviate under the uploaded file's name.

    Args:
        document_type: Type of the uploaded document.
        file: The uploaded file.

    Returns:
        A dict with the "document_id" (the uploaded file's name).
    """
    # NOTE(review): file.filename may be missing for some uploads - confirm
    # whether such requests should be rejected.
    document_id = file.filename

    # The context manager closes (and thereby deletes) the temporary file
    # even when loading fails.
    with tempfile.NamedTemporaryFile(delete=True) as tmp:
        # Write the file to a temporary file
        tmp.write(await file.read())
        tmp.flush()
        # Logged only once the content is actually on disk.
        logging.info(f"Uploaded file {file.filename} to {tmp.name}")

        # Load the document
        # NOTE(review): this call is synchronous and runs inside an async
        # handler - consider moving it off the event loop.
        documents.loadDocumentIntoWeaviate(document_type, document_id, tmp.name)

    return {"document_id": document_id}
|
||||||
|
|
||||||
@router.put("/{document_id}")
async def update_document(document_id: str):
    """
    Update a document

    Placeholder: nothing is updated yet and nothing is returned.
    """
    return None
|
|
||||||
|
|
||||||
@router.delete("/{document_id}")
async def delete_document(document_id: str):
    """
    Delete a document

    Placeholder: nothing is deleted yet and nothing is returned.
    """
    return None
|
||||||
|
|
||||||
|
@router.get("/find")
async def find_document(query: str):
    """
    Finds a document with content similar to given query

    Placeholder: no search is performed yet and nothing is returned.
    """
    # NOTE(review): the GET routes "/{document_type}" and "/{document_id}"
    # are registered before this one and also match "/find"; this route
    # probably has to be declared ahead of them to be reachable - confirm.
    return None
|
||||||
|
|
||||||
|
|
||||||
|
@ -22,4 +22,3 @@ async def ask_question(question: Question, conversation_id: str):
|
|||||||
@router.get("/{conversation_id}")
async def get_question_history(conversation_id: str):
    """
    Returns a greeting message containing the given conversation id.
    """
    # The path previously ended in a stray "}", so only URLs with a literal
    # trailing brace matched this route.
    return {"message": f"Hello question {conversation_id}!"}
|
||||||
|
|
||||||
|
21
src/schema.json
Normal file
21
src/schema.json
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
{
|
||||||
|
"classes": [
|
||||||
|
{
|
||||||
|
"class": "knowledge",
|
||||||
|
"description": "Knowledge for the language models",
|
||||||
|
"vectorizer": "none",
|
||||||
|
"properties": [
|
||||||
|
{
|
||||||
|
"name": "content",
|
||||||
|
"description": "The content of the document",
|
||||||
|
"dataType": ["text"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "document_id",
|
||||||
|
"description": "The id of the document, user facing",
|
||||||
|
"dataType": ["text"]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user