From b664edde292210d66b5f05c4ac5069d9123d1b38 Mon Sep 17 00:00:00 2001
From: Ettore Di Giacinto
Date: Thu, 25 Apr 2024 00:19:02 +0200
Subject: [PATCH] feat(rerankers): Add new backend, support jina rerankers API
 (#2121)

Signed-off-by: Ettore Di Giacinto
---
 .github/workflows/test-extra.yml              |  31 +++++
 Dockerfile                                    |   5 +-
 Makefile                                      |  13 +-
 aio/cpu/rerank.yaml                           |  27 ++++
 aio/entrypoint.sh                             |   2 +-
 aio/gpu-8g/rerank.yaml                        |  27 ++++
 aio/intel/rerank.yaml                         |  27 ++++
 backend/backend.proto                         |  24 ++++
 .../transformers/transformers-nvidia.yml      |   2 +
 .../transformers/transformers-rocm.yml        |   2 +
 .../common-env/transformers/transformers.yml  |   4 +-
 backend/python/rerankers/Makefile             |  27 ++++
 backend/python/rerankers/README.md            |   5 +
 backend/python/rerankers/reranker.py          | 123 ++++++++++++++++++
 backend/python/rerankers/run.sh               |  14 ++
 backend/python/rerankers/test.sh              |  11 ++
 backend/python/rerankers/test_reranker.py     |  90 +++++++++++++
 core/backend/rerank.go                        |  39 ++++++
 core/http/app.go                              |   1 +
 core/http/endpoints/jina/rerank.go            |  84 ++++++++++++
 core/http/routes/jina.go                      |  19 +++
 core/schema/jina.go                           |  34 +++++
 pkg/grpc/backend.go                           |   2 +
 pkg/grpc/client.go                            |  16 +++
 pkg/grpc/embed.go                             |   4 +
 25 files changed, 628 insertions(+), 5 deletions(-)
 create mode 100644 aio/cpu/rerank.yaml
 create mode 100644 aio/gpu-8g/rerank.yaml
 create mode 100644 aio/intel/rerank.yaml
 create mode 100644 backend/python/rerankers/Makefile
 create mode 100644 backend/python/rerankers/README.md
 create mode 100755 backend/python/rerankers/reranker.py
 create mode 100755 backend/python/rerankers/run.sh
 create mode 100755 backend/python/rerankers/test.sh
 create mode 100755 backend/python/rerankers/test_reranker.py
 create mode 100644 core/backend/rerank.go
 create mode 100644 core/http/endpoints/jina/rerank.go
 create mode 100644 core/http/routes/jina.go
 create mode 100644 core/schema/jina.go

diff --git a/.github/workflows/test-extra.yml b/.github/workflows/test-extra.yml
index fa45cb3c..f9476d4d 100644
--- a/.github/workflows/test-extra.yml
+++ b/.github/workflows/test-extra.yml
@@ -74,6 +74,37 @@ jobs:
           make --jobs=5 --output-sync=target -C backend/python/sentencetransformers
           make --jobs=5 --output-sync=target -C backend/python/sentencetransformers test
 
+
+  tests-rerankers:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Clone
+        uses: actions/checkout@v4
+        with:
+          submodules: true
+      - name: Dependencies
+        run: |
+          sudo apt-get update
+          sudo apt-get install build-essential ffmpeg
+          curl https://repo.anaconda.com/pkgs/misc/gpgkeys/anaconda.asc | gpg --dearmor > conda.gpg && \
+          sudo install -o root -g root -m 644 conda.gpg /usr/share/keyrings/conda-archive-keyring.gpg && \
+          gpg --keyring /usr/share/keyrings/conda-archive-keyring.gpg --no-default-keyring --fingerprint 34161F5BF5EB1D4BFBBB8F0A8AEB4F8B29D82806 && \
+          sudo /bin/bash -c 'echo "deb [arch=amd64 signed-by=/usr/share/keyrings/conda-archive-keyring.gpg] https://repo.anaconda.com/pkgs/misc/debrepo/conda stable main" > /etc/apt/sources.list.d/conda.list' && \
+          sudo /bin/bash -c 'echo "deb [arch=amd64 signed-by=/usr/share/keyrings/conda-archive-keyring.gpg] https://repo.anaconda.com/pkgs/misc/debrepo/conda stable main" | tee -a /etc/apt/sources.list.d/conda.list' && \
+          sudo apt-get update && \
+          sudo apt-get install -y conda
+          sudo apt-get install -y ca-certificates cmake curl patch python3-pip
+          sudo apt-get install -y libopencv-dev
+          pip install --user grpcio-tools
+
+          sudo rm -rfv /usr/bin/conda || true
+
+      - name: Test rerankers
+        run: |
+          export PATH=$PATH:/opt/conda/bin
+          make --jobs=5 --output-sync=target -C backend/python/rerankers
+          make --jobs=5 --output-sync=target -C backend/python/rerankers test
+
   tests-diffusers:
     runs-on: ubuntu-latest
     steps:
diff --git a/Dockerfile b/Dockerfile
index 4bc8b35e..4d12cb56 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -16,7 +16,7 @@ ARG TARGETVARIANT
 
 ENV BUILD_TYPE=${BUILD_TYPE}
 ENV DEBIAN_FRONTEND=noninteractive
-ENV EXTERNAL_GRPC_BACKENDS="coqui:/build/backend/python/coqui/run.sh,huggingface-embeddings:/build/backend/python/sentencetransformers/run.sh,petals:/build/backend/python/petals/run.sh,transformers:/build/backend/python/transformers/run.sh,sentencetransformers:/build/backend/python/sentencetransformers/run.sh,autogptq:/build/backend/python/autogptq/run.sh,bark:/build/backend/python/bark/run.sh,diffusers:/build/backend/python/diffusers/run.sh,exllama:/build/backend/python/exllama/run.sh,vall-e-x:/build/backend/python/vall-e-x/run.sh,vllm:/build/backend/python/vllm/run.sh,mamba:/build/backend/python/mamba/run.sh,exllama2:/build/backend/python/exllama2/run.sh,transformers-musicgen:/build/backend/python/transformers-musicgen/run.sh,parler-tts:/build/backend/python/parler-tts/run.sh"
+ENV EXTERNAL_GRPC_BACKENDS="coqui:/build/backend/python/coqui/run.sh,huggingface-embeddings:/build/backend/python/sentencetransformers/run.sh,petals:/build/backend/python/petals/run.sh,transformers:/build/backend/python/transformers/run.sh,sentencetransformers:/build/backend/python/sentencetransformers/run.sh,rerankers:/build/backend/python/rerankers/run.sh,autogptq:/build/backend/python/autogptq/run.sh,bark:/build/backend/python/bark/run.sh,diffusers:/build/backend/python/diffusers/run.sh,exllama:/build/backend/python/exllama/run.sh,vall-e-x:/build/backend/python/vall-e-x/run.sh,vllm:/build/backend/python/vllm/run.sh,mamba:/build/backend/python/mamba/run.sh,exllama2:/build/backend/python/exllama2/run.sh,transformers-musicgen:/build/backend/python/transformers-musicgen/run.sh,parler-tts:/build/backend/python/parler-tts/run.sh"
 
 ARG GO_TAGS="stablediffusion tinydream tts"
@@ -259,6 +259,9 @@ RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
 RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
     make -C backend/python/sentencetransformers \
     ; fi
+RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
+    make -C backend/python/rerankers \
+    ; fi
 RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
     make -C backend/python/transformers \
     ; fi
diff --git a/Makefile b/Makefile
index 662e54bd..b017982e 100644
--- a/Makefile
+++ b/Makefile
@@ -437,10 +437,10 @@ protogen-go-clean:
 	$(RM) bin/*
 
 .PHONY: protogen-python
-protogen-python: autogptq-protogen bark-protogen coqui-protogen diffusers-protogen exllama-protogen exllama2-protogen mamba-protogen petals-protogen sentencetransformers-protogen transformers-protogen parler-tts-protogen transformers-musicgen-protogen vall-e-x-protogen vllm-protogen
+protogen-python: autogptq-protogen bark-protogen coqui-protogen diffusers-protogen exllama-protogen exllama2-protogen mamba-protogen petals-protogen rerankers-protogen sentencetransformers-protogen transformers-protogen parler-tts-protogen transformers-musicgen-protogen vall-e-x-protogen vllm-protogen
 
 .PHONY: protogen-python-clean
-protogen-python-clean: autogptq-protogen-clean bark-protogen-clean coqui-protogen-clean diffusers-protogen-clean exllama-protogen-clean exllama2-protogen-clean mamba-protogen-clean petals-protogen-clean sentencetransformers-protogen-clean transformers-protogen-clean transformers-musicgen-protogen-clean parler-tts-protogen-clean vall-e-x-protogen-clean vllm-protogen-clean
+protogen-python-clean: autogptq-protogen-clean bark-protogen-clean coqui-protogen-clean diffusers-protogen-clean exllama-protogen-clean exllama2-protogen-clean mamba-protogen-clean petals-protogen-clean sentencetransformers-protogen-clean rerankers-protogen-clean transformers-protogen-clean transformers-musicgen-protogen-clean parler-tts-protogen-clean vall-e-x-protogen-clean vllm-protogen-clean
 
 .PHONY: autogptq-protogen
 autogptq-protogen:
@@ -506,6 +506,14 @@ petals-protogen:
 petals-protogen-clean:
 	$(MAKE) -C backend/python/petals protogen-clean
 
+.PHONY: rerankers-protogen
+rerankers-protogen:
+	$(MAKE) -C backend/python/rerankers protogen
+
+.PHONY: rerankers-protogen-clean
+rerankers-protogen-clean:
+	$(MAKE) -C backend/python/rerankers protogen-clean
+
 .PHONY: sentencetransformers-protogen
 sentencetransformers-protogen:
 	$(MAKE) -C backend/python/sentencetransformers protogen
@@ -564,6 +572,7 @@ prepare-extra-conda-environments: protogen-python
 	$(MAKE) -C backend/python/vllm
 	$(MAKE) -C backend/python/mamba
 	$(MAKE) -C backend/python/sentencetransformers
+	$(MAKE) -C backend/python/rerankers
 	$(MAKE) -C backend/python/transformers
 	$(MAKE) -C backend/python/transformers-musicgen
 	$(MAKE) -C backend/python/parler-tts
diff --git a/aio/cpu/rerank.yaml b/aio/cpu/rerank.yaml
new file mode 100644
index 00000000..b84755a8
--- /dev/null
+++ b/aio/cpu/rerank.yaml
@@ -0,0 +1,27 @@
+name: jina-reranker-v1-base-en
+backend: rerankers
+parameters:
+  model: cross-encoder
+
+usage: |
+    You can test this model with curl like this:
+
+    curl http://localhost:8080/v1/rerank \
+      -H "Content-Type: application/json" \
+      -d '{
+      "model": "jina-reranker-v1-base-en",
+      "query": "Organic skincare products for sensitive skin",
+      "documents": [
+        "Eco-friendly kitchenware for modern homes",
+        "Biodegradable cleaning supplies for eco-conscious consumers",
+        "Organic cotton baby clothes for sensitive skin",
+        "Natural organic skincare range for sensitive skin",
+        "Tech gadgets for smart homes: 2024 edition",
+        "Sustainable gardening tools and compost solutions",
+        "Sensitive skin-friendly facial cleansers and toners",
+        "Organic food wraps and storage solutions",
+        "All-natural pet food for dogs with allergies",
+        "Yoga mats made from recycled materials"
+      ],
+      "top_n": 3
+    }'
diff --git a/aio/entrypoint.sh b/aio/entrypoint.sh
index 5fd8d9c2..2487e64f 100755
--- a/aio/entrypoint.sh
+++ b/aio/entrypoint.sh
@@ -129,7 +129,7 @@ detect_gpu
 detect_gpu_size
 
 PROFILE="${PROFILE:-$GPU_SIZE}" # default to cpu
-export MODELS="${MODELS:-/aio/${PROFILE}/embeddings.yaml,/aio/${PROFILE}/text-to-speech.yaml,/aio/${PROFILE}/image-gen.yaml,/aio/${PROFILE}/text-to-text.yaml,/aio/${PROFILE}/speech-to-text.yaml,/aio/${PROFILE}/vision.yaml}"
+export MODELS="${MODELS:-/aio/${PROFILE}/embeddings.yaml,/aio/${PROFILE}/rerank.yaml,/aio/${PROFILE}/text-to-speech.yaml,/aio/${PROFILE}/image-gen.yaml,/aio/${PROFILE}/text-to-text.yaml,/aio/${PROFILE}/speech-to-text.yaml,/aio/${PROFILE}/vision.yaml}"
 
 check_vars
diff --git a/aio/gpu-8g/rerank.yaml b/aio/gpu-8g/rerank.yaml
new file mode 100644
index 00000000..b84755a8
--- /dev/null
+++ b/aio/gpu-8g/rerank.yaml
@@ -0,0 +1,27 @@
+name: jina-reranker-v1-base-en
+backend: rerankers
+parameters:
+  model: cross-encoder
+
+usage: |
+    You can test this model with curl like this:
+
+    curl http://localhost:8080/v1/rerank \
+      -H "Content-Type: application/json" \
+      -d '{
+      "model": "jina-reranker-v1-base-en",
+      "query": "Organic skincare products for sensitive skin",
+      "documents": [
+        "Eco-friendly kitchenware for modern homes",
+        "Biodegradable cleaning supplies for eco-conscious consumers",
+        "Organic cotton baby clothes for sensitive skin",
+        "Natural organic skincare range for sensitive skin",
+        "Tech gadgets for smart homes: 2024 edition",
+        "Sustainable gardening tools and compost solutions",
+        "Sensitive skin-friendly facial cleansers and toners",
+        "Organic food wraps and storage solutions",
+        "All-natural pet food for dogs with allergies",
+        "Yoga mats made from recycled materials"
+      ],
+      "top_n": 3
+    }'
diff --git a/aio/intel/rerank.yaml b/aio/intel/rerank.yaml
new file mode 100644
index 00000000..b84755a8
--- /dev/null
+++ b/aio/intel/rerank.yaml
@@ -0,0 +1,27 @@
+name: jina-reranker-v1-base-en
+backend: rerankers
+parameters:
+  model: cross-encoder
+
+usage: |
+    You can test this model with curl like this:
+
+    curl http://localhost:8080/v1/rerank \
+      -H "Content-Type: application/json" \
+      -d '{
+      "model": "jina-reranker-v1-base-en",
+      "query": "Organic skincare products for sensitive skin",
+      "documents": [
+        "Eco-friendly kitchenware for modern homes",
+        "Biodegradable cleaning supplies for eco-conscious consumers",
+        "Organic cotton baby clothes for sensitive skin",
+        "Natural organic skincare range for sensitive skin",
+        "Tech gadgets for smart homes: 2024 edition",
+        "Sustainable gardening tools and compost solutions",
+        "Sensitive skin-friendly facial cleansers and toners",
+        "Organic food wraps and storage solutions",
+        "All-natural pet food for dogs with allergies",
+        "Yoga mats made from recycled materials"
+      ],
+      "top_n": 3
+    }'
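For reference, the curl invocation baked into these configs translates directly to a few lines of Python. A minimal sketch using the third-party `requests` package (not part of this patch), assuming a LocalAI instance on localhost:8080 with one of the rerank.yaml configs above loaded; the response field names follow core/schema/jina.go introduced later in this patch:

```python
# Minimal sketch: call the new /v1/rerank endpoint, mirroring the curl
# example in the aio configs. `requests` is an assumed extra dependency.
import requests

payload = {
    "model": "jina-reranker-v1-base-en",
    "query": "Organic skincare products for sensitive skin",
    "documents": [
        "Eco-friendly kitchenware for modern homes",
        "Natural organic skincare range for sensitive skin",
        "Yoga mats made from recycled materials",
    ],
    "top_n": 2,
}

resp = requests.post("http://localhost:8080/v1/rerank", json=payload, timeout=120)
resp.raise_for_status()

# Each result carries the original document index, the document text,
# and a relevance score, ordered from most to least relevant.
for result in resp.json()["results"]:
    print(result["index"], result["relevance_score"], result["document"]["text"])
```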
diff --git a/backend/backend.proto b/backend/backend.proto
index ec01e4a7..778a96ff 100644
--- a/backend/backend.proto
+++ b/backend/backend.proto
@@ -23,6 +23,30 @@ service Backend {
   rpc StoresDelete(StoresDeleteOptions) returns (Result) {}
   rpc StoresGet(StoresGetOptions) returns (StoresGetResult) {}
   rpc StoresFind(StoresFindOptions) returns (StoresFindResult) {}
+
+  rpc Rerank(RerankRequest) returns (RerankResult) {}
+}
+
+message RerankRequest {
+  string query = 1;
+  repeated string documents = 2;
+  int32 top_n = 3;
+}
+
+message RerankResult {
+  Usage usage = 1;
+  repeated DocumentResult results = 2;
+}
+
+message Usage {
+  int32 total_tokens = 1;
+  int32 prompt_tokens = 2;
+}
+
+message DocumentResult {
+  int32 index = 1;
+  string text = 2;
+  float relevance_score = 3;
 }
 
 message StoresKey {
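Once `grpc_tools.protoc` has generated `backend_pb2` from this proto (via the protogen target added in backend/python/rerankers/Makefile below), the new messages map onto Python classes one-to-one. A small sketch; field names come from the definitions above, values are illustrative:

```python
# Sketch: the generated backend_pb2 module exposes the new messages as
# plain protobuf classes. Assumes protogen has been run so that
# backend_pb2.py exists next to this snippet.
import backend_pb2

request = backend_pb2.RerankRequest(
    query="Organic skincare products for sensitive skin",
    documents=["Eco-friendly kitchenware", "Natural organic skincare range"],
    top_n=1,
)

# The reply pairs usage counters with one DocumentResult per ranked document.
result = backend_pb2.RerankResult(
    usage=backend_pb2.Usage(total_tokens=12, prompt_tokens=6),
    results=[
        backend_pb2.DocumentResult(
            index=1,
            text="Natural organic skincare range",
            relevance_score=0.97,
        )
    ],
)
```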
+ """ + model_name = request.Model + try: + kwargs = {} + if request.Type != "": + kwargs['model_type'] = request.Type + if request.PipelineType != "": # Reuse the PipelineType field for language + kwargs['lang'] = request.PipelineType + self.model_name = model_name + self.model = Reranker(model_name, **kwargs) + except Exception as err: + return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}") + + # Implement your logic here for the LoadModel service + # Replace this with your desired response + return backend_pb2.Result(message="Model loaded successfully", success=True) + + def Rerank(self, request, context): + documents = [] + for idx, doc in enumerate(request.documents): + documents.append(doc) + ranked_results=self.model.rank(query=request.query, docs=documents, doc_ids=list(range(len(request.documents)))) + # Prepare results to return + results = [ + backend_pb2.DocumentResult( + index=res.doc_id, + text=res.text, + relevance_score=res.score + ) for res in ranked_results.results + ] + + # Calculate the usage and total tokens + # TODO: Implement the usage calculation with reranker + total_tokens = sum(len(doc.split()) for doc in request.documents) + len(request.query.split()) + prompt_tokens = len(request.query.split()) + usage = backend_pb2.Usage(total_tokens=total_tokens, prompt_tokens=prompt_tokens) + return backend_pb2.RerankResult(usage=usage, results=results) + +def serve(address): + server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS)) + backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) + server.add_insecure_port(address) + server.start() + print("Server started. Listening on: " + address, file=sys.stderr) + + # Define the signal handler function + def signal_handler(sig, frame): + print("Received termination signal. Shutting down...") + server.stop(0) + sys.exit(0) + + # Set the signal handlers for SIGINT and SIGTERM + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) + + try: + while True: + time.sleep(_ONE_DAY_IN_SECONDS) + except KeyboardInterrupt: + server.stop(0) + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run the gRPC server.") + parser.add_argument( + "--addr", default="localhost:50051", help="The address to bind the server to." 
diff --git a/backend/python/rerankers/run.sh b/backend/python/rerankers/run.sh
new file mode 100755
index 00000000..16d8a0bd
--- /dev/null
+++ b/backend/python/rerankers/run.sh
@@ -0,0 +1,14 @@
+#!/bin/bash
+
+##
+## A bash script wrapper that runs the reranker server with conda
+
+export PATH=$PATH:/opt/conda/bin
+
+# Activate conda environment
+source activate transformers
+
+# get the directory where the bash script is located
+DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
+
+python $DIR/reranker.py $@
diff --git a/backend/python/rerankers/test.sh b/backend/python/rerankers/test.sh
new file mode 100755
index 00000000..75316829
--- /dev/null
+++ b/backend/python/rerankers/test.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+##
+## A bash script wrapper that runs the reranker tests with conda
+
+# Activate conda environment
+source activate transformers
+
+# get the directory where the bash script is located
+DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
+
+python -m unittest $DIR/test_reranker.py
\ No newline at end of file
diff --git a/backend/python/rerankers/test_reranker.py b/backend/python/rerankers/test_reranker.py
new file mode 100755
index 00000000..c1cf3d70
--- /dev/null
+++ b/backend/python/rerankers/test_reranker.py
@@ -0,0 +1,90 @@
+"""
+A test script to test the gRPC service
+"""
+import unittest
+import subprocess
+import time
+import backend_pb2
+import backend_pb2_grpc
+
+import grpc
+
+
+class TestBackendServicer(unittest.TestCase):
+    """
+    TestBackendServicer is the class that tests the gRPC service
+    """
+    def setUp(self):
+        """
+        This method sets up the gRPC service by starting the server
+        """
+        self.service = subprocess.Popen(["python3", "reranker.py", "--addr", "localhost:50051"])
+        time.sleep(10)
+
+    def tearDown(self) -> None:
+        """
+        This method tears down the gRPC service by terminating the server
+        """
+        self.service.kill()
+        self.service.wait()
+
+    def test_server_startup(self):
+        """
+        This method tests if the server starts up successfully
+        """
+        try:
+            self.setUp()
+            with grpc.insecure_channel("localhost:50051") as channel:
+                stub = backend_pb2_grpc.BackendStub(channel)
+                response = stub.Health(backend_pb2.HealthMessage())
+                self.assertEqual(response.message, b'OK')
+        except Exception as err:
+            print(err)
+            self.fail("Server failed to start")
+        finally:
+            self.tearDown()
+
+    def test_load_model(self):
+        """
+        This method tests if the model is loaded successfully
+        """
+        try:
+            self.setUp()
+            with grpc.insecure_channel("localhost:50051") as channel:
+                stub = backend_pb2_grpc.BackendStub(channel)
+                response = stub.LoadModel(backend_pb2.ModelOptions(Model="cross-encoder"))
+                self.assertTrue(response.success)
+                self.assertEqual(response.message, "Model loaded successfully")
+        except Exception as err:
+            print(err)
+            self.fail("LoadModel service failed")
+        finally:
+            self.tearDown()
+
+    def test_rerank(self):
+        """
+        This method tests if documents are reranked successfully
+        """
+        try:
+            self.setUp()
+            with grpc.insecure_channel("localhost:50051") as channel:
+                stub = backend_pb2_grpc.BackendStub(channel)
+                request = backend_pb2.RerankRequest(
+                    query="I love you",
+                    documents=["I hate you", "I really like you"],
+                    top_n=2
+                )
+                response = stub.LoadModel(backend_pb2.ModelOptions(Model="cross-encoder"))
+                self.assertTrue(response.success)
+
+                rerank_response = stub.Rerank(request)
+                print(rerank_response.results[0])
+                self.assertIsNotNone(rerank_response.results)
+                self.assertEqual(len(rerank_response.results), 2)
+                self.assertEqual(rerank_response.results[0].text, "I really like you")
+                self.assertEqual(rerank_response.results[1].text, "I hate you")
+        except Exception as err:
+            print(err)
+            self.fail("Reranker service failed")
+        finally:
+            self.tearDown()
\ No newline at end of file
diff --git a/core/backend/rerank.go b/core/backend/rerank.go
new file mode 100644
index 00000000..810223aa
--- /dev/null
+++ b/core/backend/rerank.go
@@ -0,0 +1,39 @@
+package backend
+
+import (
+	"context"
+	"fmt"
+
+	"github.com/go-skynet/LocalAI/core/config"
+	"github.com/go-skynet/LocalAI/pkg/grpc/proto"
+	model "github.com/go-skynet/LocalAI/pkg/model"
+)
+
+func Rerank(backend, modelFile string, request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (*proto.RerankResult, error) {
+	bb := backend
+	if bb == "" {
+		return nil, fmt.Errorf("backend is required")
+	}
+
+	grpcOpts := gRPCModelOpts(backendConfig)
+
+	opts := modelOpts(config.BackendConfig{}, appConfig, []model.Option{
+		model.WithBackendString(bb),
+		model.WithModel(modelFile),
+		model.WithContext(appConfig.Context),
+		model.WithAssetDir(appConfig.AssetsDestination),
+		model.WithLoadGRPCLoadModelOpts(grpcOpts),
+	})
+	rerankModel, err := loader.BackendLoader(opts...)
+	if err != nil {
+		return nil, err
+	}
+
+	if rerankModel == nil {
+		return nil, fmt.Errorf("could not load rerank model")
+	}
+
+	res, err := rerankModel.Rerank(context.Background(), request)
+
+	return res, err
+}
diff --git a/core/http/app.go b/core/http/app.go
index 21652dd9..93eb0e20 100644
--- a/core/http/app.go
+++ b/core/http/app.go
@@ -194,6 +194,7 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi
 	routes.RegisterOpenAIRoutes(app, cl, ml, appConfig, auth)
 	routes.RegisterPagesRoutes(app, cl, ml, appConfig, auth)
 	routes.RegisterUIRoutes(app, cl, ml, appConfig, galleryService, auth)
+	routes.RegisterJINARoutes(app, cl, ml, appConfig, auth)
 
 	// Define a custom 404 handler
 	// Note: keep this at the bottom!
diff --git a/core/http/endpoints/jina/rerank.go b/core/http/endpoints/jina/rerank.go
new file mode 100644
index 00000000..bf99367e
--- /dev/null
+++ b/core/http/endpoints/jina/rerank.go
@@ -0,0 +1,84 @@
+package jina
+
+import (
+	"github.com/go-skynet/LocalAI/core/backend"
+	"github.com/go-skynet/LocalAI/core/config"
+
+	fiberContext "github.com/go-skynet/LocalAI/core/http/ctx"
+	"github.com/go-skynet/LocalAI/core/schema"
+	"github.com/go-skynet/LocalAI/pkg/grpc/proto"
+	"github.com/go-skynet/LocalAI/pkg/model"
+	"github.com/gofiber/fiber/v2"
+	"github.com/rs/zerolog/log"
+)
+
+func JINARerankEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
+	return func(c *fiber.Ctx) error {
+		req := new(schema.JINARerankRequest)
+		if err := c.BodyParser(req); err != nil {
+			return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
+				"error": "Cannot parse JSON",
+			})
+		}
+
+		input := new(schema.TTSRequest)
+
+		// Parse the body a second time to reuse TTSRequest's Model/Backend fields
+		if err := c.BodyParser(input); err != nil {
+			return err
+		}
+
+		modelFile, err := fiberContext.ModelFromContext(c, ml, input.Model, false)
+		if err != nil {
+			modelFile = input.Model
+			log.Warn().Msgf("Model not found in context: %s", input.Model)
+		}
+
+		cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
+			config.LoadOptionDebug(appConfig.Debug),
+			config.LoadOptionThreads(appConfig.Threads),
+			config.LoadOptionContextSize(appConfig.ContextSize),
+			config.LoadOptionF16(appConfig.F16),
+		)
+
+		if err != nil {
+			modelFile = input.Model
+			log.Warn().Msgf("Model not found in context: %s", input.Model)
+		} else {
+			modelFile = cfg.Model
+		}
+		log.Debug().Msgf("Request for model: %s", modelFile)
+
+		if input.Backend != "" {
+			cfg.Backend = input.Backend
+		}
+
+		request := &proto.RerankRequest{
+			Query:     req.Query,
+			TopN:      int32(req.TopN),
+			Documents: req.Documents,
+		}
+
+		results, err := backend.Rerank(cfg.Backend, modelFile, request, ml, appConfig, *cfg)
+		if err != nil {
+			return err
+		}
+
+		response := &schema.JINARerankResponse{
+			Model: req.Model,
+		}
+
+		for _, r := range results.Results {
+			response.Results = append(response.Results, schema.JINADocumentResult{
+				Index:          int(r.Index),
+				Document:       schema.JINAText{Text: r.Text},
+				RelevanceScore: float64(r.RelevanceScore),
+			})
+		}
+
+		response.Usage.TotalTokens = int(results.Usage.TotalTokens)
+		response.Usage.PromptTokens = int(results.Usage.PromptTokens)
+
+		return c.Status(fiber.StatusOK).JSON(response)
+	}
+}
diff --git a/core/http/routes/jina.go b/core/http/routes/jina.go
new file mode 100644
index 00000000..9c32c72b
--- /dev/null
+++ b/core/http/routes/jina.go
@@ -0,0 +1,19 @@
+package routes
+
+import (
+	"github.com/go-skynet/LocalAI/core/config"
+	"github.com/go-skynet/LocalAI/core/http/endpoints/jina"
+
+	"github.com/go-skynet/LocalAI/pkg/model"
+	"github.com/gofiber/fiber/v2"
+)
+
+func RegisterJINARoutes(app *fiber.App,
+	cl *config.BackendConfigLoader,
+	ml *model.ModelLoader,
+	appConfig *config.ApplicationConfig,
+	auth func(*fiber.Ctx) error) {
+
+	// POST endpoint to mimic the JINA reranking API
+	app.Post("/v1/rerank", jina.JINARerankEndpoint(cl, ml, appConfig))
+}
diff --git a/core/schema/jina.go b/core/schema/jina.go
new file mode 100644
index 00000000..7f80689c
--- /dev/null
+++ b/core/schema/jina.go
@@ -0,0 +1,34 @@
+package schema
+
+// JINARerankRequest defines the structure of the request payload
+type JINARerankRequest struct {
+	Model     string   `json:"model"`
+	Query     string   `json:"query"`
+	Documents []string `json:"documents"`
+	TopN      int      `json:"top_n"`
+}
+
+// JINADocumentResult represents a single document result
+type JINADocumentResult struct {
+	Index          int      `json:"index"`
+	Document       JINAText `json:"document"`
+	RelevanceScore float64  `json:"relevance_score"`
+}
+
+// JINAText holds the text of the document
+type JINAText struct {
+	Text string `json:"text"`
+}
+
+// JINARerankResponse defines the structure of the response payload
+type JINARerankResponse struct {
+	Model   string               `json:"model"`
+	Usage   JINAUsageInfo        `json:"usage"`
+	Results []JINADocumentResult `json:"results"`
+}
+
+// JINAUsageInfo holds information about usage of tokens
+type JINAUsageInfo struct {
+	TotalTokens  int `json:"total_tokens"`
+	PromptTokens int `json:"prompt_tokens"`
+}
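Put together, a successful call to /v1/rerank returns a JSON body shaped by JINARerankResponse above; roughly like this (values illustrative):

```json
{
  "model": "jina-reranker-v1-base-en",
  "usage": { "total_tokens": 71, "prompt_tokens": 6 },
  "results": [
    {
      "index": 3,
      "document": { "text": "Natural organic skincare range for sensitive skin" },
      "relevance_score": 0.98
    },
    {
      "index": 6,
      "document": { "text": "Sensitive skin-friendly facial cleansers and toners" },
      "relevance_score": 0.81
    }
  ]
}
```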
diff --git a/pkg/grpc/backend.go b/pkg/grpc/backend.go
index 8fb8c39d..bef9e186 100644
--- a/pkg/grpc/backend.go
+++ b/pkg/grpc/backend.go
@@ -49,4 +49,6 @@ type Backend interface {
 	StoresDelete(ctx context.Context, in *pb.StoresDeleteOptions, opts ...grpc.CallOption) (*pb.Result, error)
 	StoresGet(ctx context.Context, in *pb.StoresGetOptions, opts ...grpc.CallOption) (*pb.StoresGetResult, error)
 	StoresFind(ctx context.Context, in *pb.StoresFindOptions, opts ...grpc.CallOption) (*pb.StoresFindResult, error)
+
+	Rerank(ctx context.Context, in *pb.RerankRequest, opts ...grpc.CallOption) (*pb.RerankResult, error)
 }
diff --git a/pkg/grpc/client.go b/pkg/grpc/client.go
index 882db12a..fc4a12fa 100644
--- a/pkg/grpc/client.go
+++ b/pkg/grpc/client.go
@@ -355,3 +355,19 @@ func (c *Client) StoresFind(ctx context.Context, in *pb.StoresFindOptions, opts
 	client := pb.NewBackendClient(conn)
 	return client.StoresFind(ctx, in, opts...)
 }
+
+func (c *Client) Rerank(ctx context.Context, in *pb.RerankRequest, opts ...grpc.CallOption) (*pb.RerankResult, error) {
+	if !c.parallel {
+		c.opMutex.Lock()
+		defer c.opMutex.Unlock()
+	}
+	c.setBusy(true)
+	defer c.setBusy(false)
+	conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
+	if err != nil {
+		return nil, err
+	}
+	defer conn.Close()
+	client := pb.NewBackendClient(conn)
+	return client.Rerank(ctx, in, opts...)
+}
diff --git a/pkg/grpc/embed.go b/pkg/grpc/embed.go
index 73b185a3..694e83b0 100644
--- a/pkg/grpc/embed.go
+++ b/pkg/grpc/embed.go
@@ -101,6 +101,10 @@ func (e *embedBackend) StoresFind(ctx context.Context, in *pb.StoresFindOptions,
 	return e.s.StoresFind(ctx, in)
 }
 
+func (e *embedBackend) Rerank(ctx context.Context, in *pb.RerankRequest, opts ...grpc.CallOption) (*pb.RerankResult, error) {
+	return e.s.Rerank(ctx, in)
+}
+
 type embedBackendServerStream struct {
 	ctx context.Context
 	fn  func(s []byte)