Add add document endpoint
Parent: f9f1982442
Commit: 084e1cd383
src/config.py
@@ -2,6 +2,8 @@ import envyaml
import os
import logging
from logging import Formatter, StreamHandler
import weaviate
import json

config = envyaml.EnvYAML(os.environ.get('CONFIG_PATH', 'config.yaml'))

@@ -35,3 +37,15 @@ def initLogger() -> None:
        logging.basicConfig(level=logging.INFO)
        logging.warning("Invalid log level. Using INFO as default")


def initEnvironment() -> None:
    os.environ["OPENAI_PROXY"] = config["app.openai.url"]
    os.environ["OPENAI_API_KEY"] = config["app.openai.api_key"]


def initWeviate() -> None:
    logging.debug("Initializing Weaviate")
    client = weaviate.Client(config["app.weaviate.url"])
    with open(config["app.weaviate.schema-path"]) as file:
        schema = json.load(file)
        if not client.schema.contains(schema):
            logging.debug("Creating Weaviate schema")
            client.schema.create(schema)
config.yaml
@@ -1,6 +1,18 @@
weaviate:
  url: "localhost:"

app:
  log:
    level: ""
    level: "debug"
  weaviate:
    url: "http://localhost:12345"
    text-field: "content"
    index-name: "knowledge"
    schema-path: "schema.json"
  openai:
    url: "http://192.168.1.104:11111"
    api_key: "sk-"
  document:
    # How many characters should each chunked document be split into?
    split_chunk_size: 500
    # How much overlap should each chunk have with its neighbor?
    split_chunk_overlap: 20
    # Which model from our OpenAI API should we use?
    embeddings_model: "text-embedding"
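The settings above are consumed through EnvYAML's dotted-key lookup, mirroring the access pattern used in src/config.py. A minimal sketch of reading them back (assumes the file is found via CONFIG_PATH or sits in the working directory; the printed values reflect the config.yaml shown in this diff):

# Sketch: reading the chunking and endpoint settings with EnvYAML.
import os
import envyaml

config = envyaml.EnvYAML(os.environ.get('CONFIG_PATH', 'config.yaml'))

print(config["app.weaviate.url"])                 # "http://localhost:12345"
print(config["app.document.split_chunk_size"])    # 500
print(config["app.document.split_chunk_overlap"]) # 20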
src/internal/documentTypes.py (new file, 12 lines)
@@ -0,0 +1,12 @@
from enum import Enum


class DocumentType(str, Enum):
    """
    Enumerated type for document types that we support
    """
    markdown = "md"
    #html = "html"
    #pdf = "pdf"
    #epub = "epub"
    #odt = "odt"
    #docx = "docx"
src/internal/documents.py
@@ -1,3 +1,71 @@
from langchain.document_loaders import UnstructuredMarkdownLoader
from langchain.schema import Document
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Weaviate
from .documentTypes import DocumentType
from config import config
import logging
import weaviate
from typing import Iterable, List
from langchain.text_splitter import NLTKTextSplitter
from internal.weaviate.weaviate import WeaviateClient


# Globals
# =======
# The text splitter
textSplitter = NLTKTextSplitter(chunk_size=config["app.document.split_chunk_size"],
                                chunk_overlap=config["app.document.split_chunk_overlap"])
"""
The text splitter client
"""

db = None


def loadDocumentIntoWeaviate(documentType: DocumentType,
                             document_name: str,
                             path: str) -> None:
    """
    Loads a document into Weaviate.
    """
    global db

    documents = None

    if documentType == DocumentType.markdown:
        loader = UnstructuredMarkdownLoader(path)
        documents = loader.load()
    else:
        raise Exception("Document type not supported.")

    # Split up the document
    texts = splitDocument(documents)

    for text in texts:
        text.metadata["document_name"] = document_name

    if db is None:
        logging.debug(config["app.weaviate.url"])
        client = weaviate.Client(config["app.weaviate.url"])
        db = WeaviateClient(client,
                            text_key=config["app.weaviate.text-field"],
                            index_name=config["app.weaviate.index-name"])

    db.addDocuments(documents=texts)

    logging.info(f"Loaded document {document_name} ({len(texts)} chunks) into Weaviate")


def splitDocument(document: Document | Iterable[Document]) -> List[Document]:
    """
    Splits a document into multiple documents using NLTK
    """
    if document is None:
        raise Exception("Document is None")

    global textSplitter

    return textSplitter.split_documents(document)
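A usage sketch for the loader above (the call site and file names are hypothetical, not part of the commit). NLTKTextSplitter needs the NLTK "punkt" tokenizer data to be available before it can split:

# Sketch: loading a markdown file into Weaviate via the new module.
import nltk
from internal import documents
from internal.documentTypes import DocumentType

nltk.download("punkt")  # one-time download of sentence tokenizer data

documents.loadDocumentIntoWeaviate(DocumentType.markdown,
                                   document_name="notes.md",
                                   path="/tmp/notes.md")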
src/internal/embeddings/EmbeddingProvider.py (new file, 11 lines)
@@ -0,0 +1,11 @@
from abc import ABC, abstractmethod
from typing import List


class EmbeddingProvider(ABC):
    @abstractmethod
    def getEmbedding(self, word: str) -> List[float]:
        """
        Returns the embedding for the given word
        """
        pass
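This abstract base class is what WeaviateClient depends on, so any backend that can map text to a vector can plug in. A sketch of an alternative implementation, e.g. a deterministic stub for tests (hypothetical class, not part of the commit):

# Sketch: a stub provider that never calls a real embeddings API.
from typing import List
from internal.embeddings.EmbeddingProvider import EmbeddingProvider


class FakeEmbeddingProvider(EmbeddingProvider):
    """Returns a constant vector so tests stay offline and deterministic."""

    def __init__(self, dimensions: int = 1536):
        self.dimensions = dimensions

    def getEmbedding(self, word: str) -> List[float]:
        return [0.0] * self.dimensions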
src/internal/embeddings/OpenAIEmbeddingProvider.py (new file, 48 lines)
@@ -0,0 +1,48 @@
from .EmbeddingProvider import EmbeddingProvider
from typing import List
import openai
from config import config


class OpenAIEmbeddingProvider(EmbeddingProvider):

    def __init__(self):
        super().__init__()
        self.openai_api_key = config["app.openai.api_key"]
        self.openai_url = config["app.openai.url"]
        self.model = config["app.document.embeddings_model"]

    def getEmbedding(self, text: str) -> List[float]:
        """
        Returns the embedding for the given string
        """
        openai.api_key = self.openai_api_key
        openai.api_base = self.openai_url
        return openai.Embedding.create(input=[text],
                                       model=self.model)['data'][0]['embedding']

    @property
    def OPENAI_API_KEY(self):
        return self.openai_api_key

    @OPENAI_API_KEY.setter
    def OPENAI_API_KEY(self, value):
        self.openai_api_key = value

    @property
    def OPENAI_URL(self):
        return self.openai_url

    @OPENAI_URL.setter
    def OPENAI_URL(self, value):
        self.openai_url = value

    @property
    def MODEL(self):
        return self.model

    @MODEL.setter
    def MODEL(self, value):
        self.model = value
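A minimal usage sketch of the provider above (assumes the config has been loaded and the configured endpoint speaks the OpenAI embeddings API):

# Sketch: fetching a single embedding through the configured proxy.
from internal.embeddings.OpenAIEmbeddingProvider import OpenAIEmbeddingProvider

provider = OpenAIEmbeddingProvider()
vector = provider.getEmbedding("What is a vector database?")
print(len(vector))  # dimensionality depends on the configured embeddings_model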
src/internal/embeddings/__init__.py (new file, empty)
src/internal/weaviate/weaviate.py (new file, 121 lines)
@@ -0,0 +1,121 @@
import weaviate
import json
import logging
from typing import Iterable, List
from internal.embeddings.OpenAIEmbeddingProvider import OpenAIEmbeddingProvider
from internal.embeddings.EmbeddingProvider import EmbeddingProvider
from langchain.schema import Document


class WeaviateClient:
    # Constructor
    def __init__(self, client: weaviate.Client,
                 text_key: str = "content",
                 index_name: str = "documents",
                 embeddingProvider: EmbeddingProvider = OpenAIEmbeddingProvider(),
                 ):
        self.client = client
        self.embeddingProvider = embeddingProvider
        self.text_key = text_key
        self.index_name = index_name

    @classmethod
    def fromUrl(cls, endpoint: str):
        """
        Creates a WeaviateClient from an endpoint
        """
        client = weaviate.Client(endpoint)
        return cls(client)

    def addDocuments(self, documents: Iterable[Document]) -> None:
        """
        Adds a list of documents to the store
        """
        if self.embeddingProvider is None:
            raise Exception("No embedding provider set")

        with self.client.batch as batch:
            for i, text in enumerate(documents):
                data_properties = {self.text_key: text.page_content}
                if text.metadata is not None:
                    for key, val in text.metadata.items():
                        data_properties[key] = val

                vector = self.embeddingProvider.getEmbedding(text.page_content)
                batch.add_data_object(
                    data_object=data_properties,
                    class_name=self.index_name,
                    vector=vector
                )

    def addDocument(self, document: Document) -> None:
        """
        Adds a given document to the store
        """
        self.addDocuments([document])

    def similaritySearch(self, query: str, k: int = 10) -> List[str]:
        """
        Searches for similar documents

        Args:
            query: Text to look up
            k: Number of results to return, default 10
        """
        pass

    def removeDocument(self, document_id: str) -> None:
        """
        Removes a document from the store

        Args:
            document_id: Id of the document to remove
        """
        pass

    def getDocument(self, document_id: str) -> str:
        """
        Returns a document from the store

        Args:
            document_id: Id of the document to return
        """
        pass

    @property
    def ENDPOINT(self):
        return self.endpoint

    @ENDPOINT.setter
    def ENDPOINT(self, endpoint: str):
        self.endpoint = endpoint

    @property
    def EMBEDDING_PROVIDER(self):
        return self.embeddingProvider

    @EMBEDDING_PROVIDER.setter
    def EMBEDDING_PROVIDER(self, embeddingProvider: EmbeddingProvider):
        self.embeddingProvider = embeddingProvider

    @property
    def TEXT_KEY(self):
        return self.text_key

    @TEXT_KEY.setter
    def TEXT_KEY(self, text_key: str):
        self.text_key = text_key

    @property
    def INDEX_NAME(self):
        return self.index_name

    @INDEX_NAME.setter
    def INDEX_NAME(self, index_name: str):
        self.index_name = index_name
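A minimal sketch of the wrapper above being used directly (hypothetical values; in this commit it is constructed inside internal/documents.py with settings taken from config.yaml):

# Sketch: constructing the client from a URL and adding one document.
from langchain.schema import Document
from internal.weaviate.weaviate import WeaviateClient

db = WeaviateClient.fromUrl("http://localhost:12345")
db.TEXT_KEY = "content"      # match the property name in schema.json
db.INDEX_NAME = "knowledge"  # match the class name in schema.json

db.addDocument(Document(page_content="Weaviate stores vectors.",
                        metadata={"document_name": "example.md"}))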
src/main.py
@@ -4,18 +4,19 @@ from fastapi.responses import RedirectResponse
from routers import question, documents

import logging
from config import config, initLogger
from config import config, initLogger, initEnvironment, initWeviate

initEnvironment()
initLogger()

# Init weaviate, if not done already
initWeviate()

app = FastAPI()

app.include_router(question.router)
app.include_router(documents.router)

logging.warn("Test message")


@app.get("/")
async def root():
    """
src/routers/documents.py
@@ -1,6 +1,11 @@
from fastapi import APIRouter
from fastapi import APIRouter, UploadFile
from pydantic import BaseModel
from enum import Enum
import tempfile
import logging
from internal.documentTypes import DocumentType

from internal import documents

router = APIRouter(
    prefix="/documents",

@@ -8,21 +13,6 @@ router = APIRouter(
    responses={404: {"description": "Not found"}},
)

class Document(BaseModel):
    id: int
    title: str
    content: str

# Document Type enum
class DocumentType(str, Enum):
    markdown = "md"
    #html = "html"
    #pdf = "pdf"
    epub = "epub"
    odt = "odt"
    docx = "docx"


@router.get("/{document_type}")
async def read_documents(document_type: DocumentType):
    """

@@ -31,30 +21,54 @@ async def read_documents(document_type: DocumentType):
    pass

@router.get("/{document_id}")
async def read_document(document_id: int):
async def read_document(document_id: str):
    """
    Get a specific document
    """
    pass

@router.post("/")
async def create_document(document: Document):
async def create_document(document_type: DocumentType,
                          file: UploadFile):
    """
    Create a new document
    """
    pass
    tmp = tempfile.NamedTemporaryFile(delete=True)
    document_id = file.filename

    try:
        logging.info(f"Uploaded file {file.filename} to {tmp.name}")

        # Write the file to a temporary file
        tmp.write(await file.read())
        tmp.flush()

        # Load the document
        documents.loadDocumentIntoWeaviate(document_type, document_id, tmp.name)

    finally:
        tmp.close()

    return {"document_id": document_id}

@router.put("/{document_id}")
async def update_document(document_id: int, document: Document):
async def update_document(document_id: str):
    """
    Update a document
    """
    pass

@router.delete("/{document_id}")
async def delete_document(document_id: int):
async def delete_document(document_id: str):
    """
    Delete a document
    """
    pass

@router.get("/find")
async def find_document(query: str):
    """
    Finds a document with content similar to given query
    """
    pass
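A hedged sketch of calling the new upload endpoint with the requests library (assumes the API runs on localhost:8000; because document_type is a plain enum parameter alongside an UploadFile, FastAPI should expose it as a query parameter here):

# Sketch: uploading a markdown file to POST /documents/.
import requests

with open("notes.md", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/documents/",
        params={"document_type": "md"},
        files={"file": ("notes.md", f, "text/markdown")},
    )

print(resp.json())  # expected: {"document_id": "notes.md"}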
src/routers/question.py
@@ -22,4 +22,3 @@ async def ask_question(question: Question, conversation_id: str):
@router.get("/{conversation_id}}")
async def get_question_history(conversation_id: str):
    return {"message": f"Hello question {conversation_id}!"}
src/schema.json (new file, 21 lines)
@@ -0,0 +1,21 @@
{
    "classes": [
        {
            "class": "knowledge",
            "description": "Knowledge for the language models",
            "vectorizer": "none",
            "properties": [
                {
                    "name": "content",
                    "description": "The content of the document",
                    "dataType": ["text"]
                },
                {
                    "name": "document_id",
                    "description": "The id of the document, user facing",
                    "dataType": ["text"]
                }
            ]
        }
    ]
}
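A sketch of querying the "knowledge" class defined above with the v3 weaviate client; since the vectorizer is "none", a query vector has to be supplied explicitly (here produced with the commit's OpenAIEmbeddingProvider; the URL matches config.yaml):

# Sketch: nearest-neighbour search over the "knowledge" class.
import weaviate
from internal.embeddings.OpenAIEmbeddingProvider import OpenAIEmbeddingProvider

client = weaviate.Client("http://localhost:12345")
vector = OpenAIEmbeddingProvider().getEmbedding("vector databases")

result = (client.query
          .get("knowledge", ["content", "document_id"])
          .with_near_vector({"vector": vector})
          .with_limit(3)
          .do())
print(result)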