diff --git a/Dockerfile b/Dockerfile index 16344473..c7da9f63 100644 --- a/Dockerfile +++ b/Dockerfile @@ -12,7 +12,7 @@ ARG TARGETARCH ARG TARGETVARIANT ENV BUILD_TYPE=${BUILD_TYPE} -ENV EXTERNAL_GRPC_BACKENDS="huggingface-embeddings:/build/backend/python/sentencetransformers/run.sh,transformers:/build/backend/python/transformers/run.sh,sentencetransformers:/build/backend/python/sentencetransformers/run.sh,autogptq:/build/backend/python/autogptq/run.sh,bark:/build/backend/python/bark/run.sh,diffusers:/build/backend/python/diffusers/run.sh,exllama:/build/backend/python/exllama/run.sh,vall-e-x:/build/backend/python/vall-e-x/run.sh,vllm:/build/backend/python/vllm/run.sh" +ENV EXTERNAL_GRPC_BACKENDS="huggingface-embeddings:/build/backend/python/sentencetransformers/run.sh,petals:/build/backend/python/petals/run.sh,transformers:/build/backend/python/transformers/run.sh,sentencetransformers:/build/backend/python/sentencetransformers/run.sh,autogptq:/build/backend/python/autogptq/run.sh,bark:/build/backend/python/bark/run.sh,diffusers:/build/backend/python/diffusers/run.sh,exllama:/build/backend/python/exllama/run.sh,vall-e-x:/build/backend/python/vall-e-x/run.sh,vllm:/build/backend/python/vllm/run.sh" ENV GALLERIES='[{"name":"model-gallery", "url":"github:go-skynet/model-gallery/index.yaml"}, {"url": "github:go-skynet/model-gallery/huggingface.yaml","name":"huggingface"}]' ARG GO_TAGS="stablediffusion tts" @@ -181,13 +181,18 @@ RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \ RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \ PATH=$PATH:/opt/conda/bin make -C backend/python/exllama \ ; fi +RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \ + PATH=$PATH:/opt/conda/bin make -C backend/python/petals \ + ; fi # Copy VALLE-X as it's not a real "lib" +# TODO: this is wrong - we should copy the lib into the conda env path RUN if [ -d /usr/lib/vall-e-x ]; then \ cp -rfv /usr/lib/vall-e-x/* ./ ; \ fi # we also copy exllama libs over to resolve exllama import error +# TODO: check if this is still needed RUN if [ -d /usr/local/lib/python3.9/dist-packages/exllama ]; then \ cp -rfv /usr/local/lib/python3.9/dist-packages/exllama backend/python/exllama/;\ fi diff --git a/Makefile b/Makefile index c46e4fc3..22689eb3 100644 --- a/Makefile +++ b/Makefile @@ -388,6 +388,7 @@ protogen-python: python3 -m grpc_tools.protoc -Ibackend/ --python_out=backend/python/diffusers/ --grpc_python_out=backend/python/diffusers/ backend/backend.proto python3 -m grpc_tools.protoc -Ibackend/ --python_out=backend/python/vall-e-x/ --grpc_python_out=backend/python/vall-e-x/ backend/backend.proto python3 -m grpc_tools.protoc -Ibackend/ --python_out=backend/python/vllm/ --grpc_python_out=backend/python/vllm/ backend/backend.proto + python3 -m grpc_tools.protoc -Ibackend/ --python_out=backend/python/petals/ --grpc_python_out=backend/python/petals/ backend/backend.proto ## GRPC # Note: it is duplicated in the Dockerfile @@ -400,6 +401,7 @@ prepare-extra-conda-environments: $(MAKE) -C backend/python/transformers $(MAKE) -C backend/python/vall-e-x $(MAKE) -C backend/python/exllama + $(MAKE) -C backend/python/petals backend-assets/grpc: diff --git a/backend/python/petals/Makefile b/backend/python/petals/Makefile new file mode 100644 index 00000000..db71a175 --- /dev/null +++ b/backend/python/petals/Makefile @@ -0,0 +1,11 @@ +.PONY: petals +petals: + @echo "Creating virtual environment..." + @conda env create --name petals --file petals.yml + @echo "Virtual environment created." + +.PONY: run +run: + @echo "Running petals..." + bash run.sh + @echo "petals run." \ No newline at end of file diff --git a/backend/python/petals/backend_pb2.py b/backend/python/petals/backend_pb2.py new file mode 100644 index 00000000..9b89eb2a --- /dev/null +++ b/backend/python/petals/backend_pb2.py @@ -0,0 +1,61 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: backend.proto +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rbackend.proto\x12\x07\x62\x61\x63kend\"\x0f\n\rHealthMessage\"\xa6\x06\n\x0ePredictOptions\x12\x0e\n\x06Prompt\x18\x01 \x01(\t\x12\x0c\n\x04Seed\x18\x02 \x01(\x05\x12\x0f\n\x07Threads\x18\x03 \x01(\x05\x12\x0e\n\x06Tokens\x18\x04 \x01(\x05\x12\x0c\n\x04TopK\x18\x05 \x01(\x05\x12\x0e\n\x06Repeat\x18\x06 \x01(\x05\x12\r\n\x05\x42\x61tch\x18\x07 \x01(\x05\x12\r\n\x05NKeep\x18\x08 \x01(\x05\x12\x13\n\x0bTemperature\x18\t \x01(\x02\x12\x0f\n\x07Penalty\x18\n \x01(\x02\x12\r\n\x05\x46\x31\x36KV\x18\x0b \x01(\x08\x12\x11\n\tDebugMode\x18\x0c \x01(\x08\x12\x13\n\x0bStopPrompts\x18\r \x03(\t\x12\x11\n\tIgnoreEOS\x18\x0e \x01(\x08\x12\x19\n\x11TailFreeSamplingZ\x18\x0f \x01(\x02\x12\x10\n\x08TypicalP\x18\x10 \x01(\x02\x12\x18\n\x10\x46requencyPenalty\x18\x11 \x01(\x02\x12\x17\n\x0fPresencePenalty\x18\x12 \x01(\x02\x12\x10\n\x08Mirostat\x18\x13 \x01(\x05\x12\x13\n\x0bMirostatETA\x18\x14 \x01(\x02\x12\x13\n\x0bMirostatTAU\x18\x15 \x01(\x02\x12\x12\n\nPenalizeNL\x18\x16 \x01(\x08\x12\x11\n\tLogitBias\x18\x17 \x01(\t\x12\r\n\x05MLock\x18\x19 \x01(\x08\x12\x0c\n\x04MMap\x18\x1a \x01(\x08\x12\x16\n\x0ePromptCacheAll\x18\x1b \x01(\x08\x12\x15\n\rPromptCacheRO\x18\x1c \x01(\x08\x12\x0f\n\x07Grammar\x18\x1d \x01(\t\x12\x0f\n\x07MainGPU\x18\x1e \x01(\t\x12\x13\n\x0bTensorSplit\x18\x1f \x01(\t\x12\x0c\n\x04TopP\x18 \x01(\x02\x12\x17\n\x0fPromptCachePath\x18! \x01(\t\x12\r\n\x05\x44\x65\x62ug\x18\" \x01(\x08\x12\x17\n\x0f\x45mbeddingTokens\x18# \x03(\x05\x12\x12\n\nEmbeddings\x18$ \x01(\t\x12\x14\n\x0cRopeFreqBase\x18% \x01(\x02\x12\x15\n\rRopeFreqScale\x18& \x01(\x02\x12\x1b\n\x13NegativePromptScale\x18\' \x01(\x02\x12\x16\n\x0eNegativePrompt\x18( \x01(\t\x12\x0e\n\x06NDraft\x18) \x01(\x05\x12\x0e\n\x06Images\x18* \x03(\t\"\x18\n\x05Reply\x12\x0f\n\x07message\x18\x01 \x01(\x0c\"\x99\x07\n\x0cModelOptions\x12\r\n\x05Model\x18\x01 \x01(\t\x12\x13\n\x0b\x43ontextSize\x18\x02 \x01(\x05\x12\x0c\n\x04Seed\x18\x03 \x01(\x05\x12\x0e\n\x06NBatch\x18\x04 \x01(\x05\x12\x11\n\tF16Memory\x18\x05 \x01(\x08\x12\r\n\x05MLock\x18\x06 \x01(\x08\x12\x0c\n\x04MMap\x18\x07 \x01(\x08\x12\x11\n\tVocabOnly\x18\x08 \x01(\x08\x12\x0f\n\x07LowVRAM\x18\t \x01(\x08\x12\x12\n\nEmbeddings\x18\n \x01(\x08\x12\x0c\n\x04NUMA\x18\x0b \x01(\x08\x12\x12\n\nNGPULayers\x18\x0c \x01(\x05\x12\x0f\n\x07MainGPU\x18\r \x01(\t\x12\x13\n\x0bTensorSplit\x18\x0e \x01(\t\x12\x0f\n\x07Threads\x18\x0f \x01(\x05\x12\x19\n\x11LibrarySearchPath\x18\x10 \x01(\t\x12\x14\n\x0cRopeFreqBase\x18\x11 \x01(\x02\x12\x15\n\rRopeFreqScale\x18\x12 \x01(\x02\x12\x12\n\nRMSNormEps\x18\x13 \x01(\x02\x12\x0c\n\x04NGQA\x18\x14 \x01(\x05\x12\x11\n\tModelFile\x18\x15 \x01(\t\x12\x0e\n\x06\x44\x65vice\x18\x16 \x01(\t\x12\x11\n\tUseTriton\x18\x17 \x01(\x08\x12\x15\n\rModelBaseName\x18\x18 \x01(\t\x12\x18\n\x10UseFastTokenizer\x18\x19 \x01(\x08\x12\x14\n\x0cPipelineType\x18\x1a \x01(\t\x12\x15\n\rSchedulerType\x18\x1b \x01(\t\x12\x0c\n\x04\x43UDA\x18\x1c \x01(\x08\x12\x10\n\x08\x43\x46GScale\x18\x1d \x01(\x02\x12\x0f\n\x07IMG2IMG\x18\x1e \x01(\x08\x12\x11\n\tCLIPModel\x18\x1f \x01(\t\x12\x15\n\rCLIPSubfolder\x18 \x01(\t\x12\x10\n\x08\x43LIPSkip\x18! \x01(\x05\x12\x11\n\tTokenizer\x18\" \x01(\t\x12\x10\n\x08LoraBase\x18# \x01(\t\x12\x13\n\x0bLoraAdapter\x18$ \x01(\t\x12\x11\n\tLoraScale\x18* \x01(\x02\x12\x11\n\tNoMulMatQ\x18% \x01(\x08\x12\x12\n\nDraftModel\x18\' \x01(\t\x12\x11\n\tAudioPath\x18& \x01(\t\x12\x14\n\x0cQuantization\x18( \x01(\t\x12\x0e\n\x06MMProj\x18) \x01(\t\x12\x13\n\x0bRopeScaling\x18+ \x01(\t\x12\x15\n\rYarnExtFactor\x18, \x01(\x02\x12\x16\n\x0eYarnAttnFactor\x18- \x01(\x02\x12\x14\n\x0cYarnBetaFast\x18. \x01(\x02\x12\x14\n\x0cYarnBetaSlow\x18/ \x01(\x02\"*\n\x06Result\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\"%\n\x0f\x45mbeddingResult\x12\x12\n\nembeddings\x18\x01 \x03(\x02\"C\n\x11TranscriptRequest\x12\x0b\n\x03\x64st\x18\x02 \x01(\t\x12\x10\n\x08language\x18\x03 \x01(\t\x12\x0f\n\x07threads\x18\x04 \x01(\r\"N\n\x10TranscriptResult\x12,\n\x08segments\x18\x01 \x03(\x0b\x32\x1a.backend.TranscriptSegment\x12\x0c\n\x04text\x18\x02 \x01(\t\"Y\n\x11TranscriptSegment\x12\n\n\x02id\x18\x01 \x01(\x05\x12\r\n\x05start\x18\x02 \x01(\x03\x12\x0b\n\x03\x65nd\x18\x03 \x01(\x03\x12\x0c\n\x04text\x18\x04 \x01(\t\x12\x0e\n\x06tokens\x18\x05 \x03(\x05\"\xd7\x01\n\x14GenerateImageRequest\x12\x0e\n\x06height\x18\x01 \x01(\x05\x12\r\n\x05width\x18\x02 \x01(\x05\x12\x0c\n\x04mode\x18\x03 \x01(\x05\x12\x0c\n\x04step\x18\x04 \x01(\x05\x12\x0c\n\x04seed\x18\x05 \x01(\x05\x12\x17\n\x0fpositive_prompt\x18\x06 \x01(\t\x12\x17\n\x0fnegative_prompt\x18\x07 \x01(\t\x12\x0b\n\x03\x64st\x18\x08 \x01(\t\x12\x0b\n\x03src\x18\t \x01(\t\x12\x18\n\x10\x45nableParameters\x18\n \x01(\t\x12\x10\n\x08\x43LIPSkip\x18\x0b \x01(\x05\"6\n\nTTSRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\r\n\x05model\x18\x02 \x01(\t\x12\x0b\n\x03\x64st\x18\x03 \x01(\t\"6\n\x14TokenizationResponse\x12\x0e\n\x06length\x18\x01 \x01(\x05\x12\x0e\n\x06tokens\x18\x02 \x03(\x05\"\x8e\x01\n\x0fMemoryUsageData\x12\r\n\x05total\x18\x01 \x01(\x04\x12:\n\tbreakdown\x18\x02 \x03(\x0b\x32\'.backend.MemoryUsageData.BreakdownEntry\x1a\x30\n\x0e\x42reakdownEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x04:\x02\x38\x01\"\xad\x01\n\x0eStatusResponse\x12,\n\x05state\x18\x01 \x01(\x0e\x32\x1d.backend.StatusResponse.State\x12(\n\x06memory\x18\x02 \x01(\x0b\x32\x18.backend.MemoryUsageData\"C\n\x05State\x12\x11\n\rUNINITIALIZED\x10\x00\x12\x08\n\x04\x42USY\x10\x01\x12\t\n\x05READY\x10\x02\x12\x12\n\x05\x45RROR\x10\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01\x32\xf4\x04\n\x07\x42\x61\x63kend\x12\x32\n\x06Health\x12\x16.backend.HealthMessage\x1a\x0e.backend.Reply\"\x00\x12\x34\n\x07Predict\x12\x17.backend.PredictOptions\x1a\x0e.backend.Reply\"\x00\x12\x35\n\tLoadModel\x12\x15.backend.ModelOptions\x1a\x0f.backend.Result\"\x00\x12<\n\rPredictStream\x12\x17.backend.PredictOptions\x1a\x0e.backend.Reply\"\x00\x30\x01\x12@\n\tEmbedding\x12\x17.backend.PredictOptions\x1a\x18.backend.EmbeddingResult\"\x00\x12\x41\n\rGenerateImage\x12\x1d.backend.GenerateImageRequest\x1a\x0f.backend.Result\"\x00\x12M\n\x12\x41udioTranscription\x12\x1a.backend.TranscriptRequest\x1a\x19.backend.TranscriptResult\"\x00\x12-\n\x03TTS\x12\x13.backend.TTSRequest\x1a\x0f.backend.Result\"\x00\x12J\n\x0eTokenizeString\x12\x17.backend.PredictOptions\x1a\x1d.backend.TokenizationResponse\"\x00\x12;\n\x06Status\x12\x16.backend.HealthMessage\x1a\x17.backend.StatusResponse\"\x00\x42Z\n\x19io.skynet.localai.backendB\x0eLocalAIBackendP\x01Z+github.com/go-skynet/LocalAI/pkg/grpc/protob\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'backend_pb2', _globals) +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = b'\n\031io.skynet.localai.backendB\016LocalAIBackendP\001Z+github.com/go-skynet/LocalAI/pkg/grpc/proto' + _MEMORYUSAGEDATA_BREAKDOWNENTRY._options = None + _MEMORYUSAGEDATA_BREAKDOWNENTRY._serialized_options = b'8\001' + _globals['_HEALTHMESSAGE']._serialized_start=26 + _globals['_HEALTHMESSAGE']._serialized_end=41 + _globals['_PREDICTOPTIONS']._serialized_start=44 + _globals['_PREDICTOPTIONS']._serialized_end=850 + _globals['_REPLY']._serialized_start=852 + _globals['_REPLY']._serialized_end=876 + _globals['_MODELOPTIONS']._serialized_start=879 + _globals['_MODELOPTIONS']._serialized_end=1800 + _globals['_RESULT']._serialized_start=1802 + _globals['_RESULT']._serialized_end=1844 + _globals['_EMBEDDINGRESULT']._serialized_start=1846 + _globals['_EMBEDDINGRESULT']._serialized_end=1883 + _globals['_TRANSCRIPTREQUEST']._serialized_start=1885 + _globals['_TRANSCRIPTREQUEST']._serialized_end=1952 + _globals['_TRANSCRIPTRESULT']._serialized_start=1954 + _globals['_TRANSCRIPTRESULT']._serialized_end=2032 + _globals['_TRANSCRIPTSEGMENT']._serialized_start=2034 + _globals['_TRANSCRIPTSEGMENT']._serialized_end=2123 + _globals['_GENERATEIMAGEREQUEST']._serialized_start=2126 + _globals['_GENERATEIMAGEREQUEST']._serialized_end=2341 + _globals['_TTSREQUEST']._serialized_start=2343 + _globals['_TTSREQUEST']._serialized_end=2397 + _globals['_TOKENIZATIONRESPONSE']._serialized_start=2399 + _globals['_TOKENIZATIONRESPONSE']._serialized_end=2453 + _globals['_MEMORYUSAGEDATA']._serialized_start=2456 + _globals['_MEMORYUSAGEDATA']._serialized_end=2598 + _globals['_MEMORYUSAGEDATA_BREAKDOWNENTRY']._serialized_start=2550 + _globals['_MEMORYUSAGEDATA_BREAKDOWNENTRY']._serialized_end=2598 + _globals['_STATUSRESPONSE']._serialized_start=2601 + _globals['_STATUSRESPONSE']._serialized_end=2774 + _globals['_STATUSRESPONSE_STATE']._serialized_start=2707 + _globals['_STATUSRESPONSE_STATE']._serialized_end=2774 + _globals['_BACKEND']._serialized_start=2777 + _globals['_BACKEND']._serialized_end=3405 +# @@protoc_insertion_point(module_scope) diff --git a/backend/python/petals/backend_pb2_grpc.py b/backend/python/petals/backend_pb2_grpc.py new file mode 100644 index 00000000..79a7677f --- /dev/null +++ b/backend/python/petals/backend_pb2_grpc.py @@ -0,0 +1,363 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + +import backend_pb2 as backend__pb2 + + +class BackendStub(object): + """Missing associated documentation comment in .proto file.""" + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.Health = channel.unary_unary( + '/backend.Backend/Health', + request_serializer=backend__pb2.HealthMessage.SerializeToString, + response_deserializer=backend__pb2.Reply.FromString, + ) + self.Predict = channel.unary_unary( + '/backend.Backend/Predict', + request_serializer=backend__pb2.PredictOptions.SerializeToString, + response_deserializer=backend__pb2.Reply.FromString, + ) + self.LoadModel = channel.unary_unary( + '/backend.Backend/LoadModel', + request_serializer=backend__pb2.ModelOptions.SerializeToString, + response_deserializer=backend__pb2.Result.FromString, + ) + self.PredictStream = channel.unary_stream( + '/backend.Backend/PredictStream', + request_serializer=backend__pb2.PredictOptions.SerializeToString, + response_deserializer=backend__pb2.Reply.FromString, + ) + self.Embedding = channel.unary_unary( + '/backend.Backend/Embedding', + request_serializer=backend__pb2.PredictOptions.SerializeToString, + response_deserializer=backend__pb2.EmbeddingResult.FromString, + ) + self.GenerateImage = channel.unary_unary( + '/backend.Backend/GenerateImage', + request_serializer=backend__pb2.GenerateImageRequest.SerializeToString, + response_deserializer=backend__pb2.Result.FromString, + ) + self.AudioTranscription = channel.unary_unary( + '/backend.Backend/AudioTranscription', + request_serializer=backend__pb2.TranscriptRequest.SerializeToString, + response_deserializer=backend__pb2.TranscriptResult.FromString, + ) + self.TTS = channel.unary_unary( + '/backend.Backend/TTS', + request_serializer=backend__pb2.TTSRequest.SerializeToString, + response_deserializer=backend__pb2.Result.FromString, + ) + self.TokenizeString = channel.unary_unary( + '/backend.Backend/TokenizeString', + request_serializer=backend__pb2.PredictOptions.SerializeToString, + response_deserializer=backend__pb2.TokenizationResponse.FromString, + ) + self.Status = channel.unary_unary( + '/backend.Backend/Status', + request_serializer=backend__pb2.HealthMessage.SerializeToString, + response_deserializer=backend__pb2.StatusResponse.FromString, + ) + + +class BackendServicer(object): + """Missing associated documentation comment in .proto file.""" + + def Health(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def Predict(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def LoadModel(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def PredictStream(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def Embedding(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def GenerateImage(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def AudioTranscription(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def TTS(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def TokenizeString(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def Status(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_BackendServicer_to_server(servicer, server): + rpc_method_handlers = { + 'Health': grpc.unary_unary_rpc_method_handler( + servicer.Health, + request_deserializer=backend__pb2.HealthMessage.FromString, + response_serializer=backend__pb2.Reply.SerializeToString, + ), + 'Predict': grpc.unary_unary_rpc_method_handler( + servicer.Predict, + request_deserializer=backend__pb2.PredictOptions.FromString, + response_serializer=backend__pb2.Reply.SerializeToString, + ), + 'LoadModel': grpc.unary_unary_rpc_method_handler( + servicer.LoadModel, + request_deserializer=backend__pb2.ModelOptions.FromString, + response_serializer=backend__pb2.Result.SerializeToString, + ), + 'PredictStream': grpc.unary_stream_rpc_method_handler( + servicer.PredictStream, + request_deserializer=backend__pb2.PredictOptions.FromString, + response_serializer=backend__pb2.Reply.SerializeToString, + ), + 'Embedding': grpc.unary_unary_rpc_method_handler( + servicer.Embedding, + request_deserializer=backend__pb2.PredictOptions.FromString, + response_serializer=backend__pb2.EmbeddingResult.SerializeToString, + ), + 'GenerateImage': grpc.unary_unary_rpc_method_handler( + servicer.GenerateImage, + request_deserializer=backend__pb2.GenerateImageRequest.FromString, + response_serializer=backend__pb2.Result.SerializeToString, + ), + 'AudioTranscription': grpc.unary_unary_rpc_method_handler( + servicer.AudioTranscription, + request_deserializer=backend__pb2.TranscriptRequest.FromString, + response_serializer=backend__pb2.TranscriptResult.SerializeToString, + ), + 'TTS': grpc.unary_unary_rpc_method_handler( + servicer.TTS, + request_deserializer=backend__pb2.TTSRequest.FromString, + response_serializer=backend__pb2.Result.SerializeToString, + ), + 'TokenizeString': grpc.unary_unary_rpc_method_handler( + servicer.TokenizeString, + request_deserializer=backend__pb2.PredictOptions.FromString, + response_serializer=backend__pb2.TokenizationResponse.SerializeToString, + ), + 'Status': grpc.unary_unary_rpc_method_handler( + servicer.Status, + request_deserializer=backend__pb2.HealthMessage.FromString, + response_serializer=backend__pb2.StatusResponse.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'backend.Backend', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + + + # This class is part of an EXPERIMENTAL API. +class Backend(object): + """Missing associated documentation comment in .proto file.""" + + @staticmethod + def Health(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/backend.Backend/Health', + backend__pb2.HealthMessage.SerializeToString, + backend__pb2.Reply.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def Predict(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/backend.Backend/Predict', + backend__pb2.PredictOptions.SerializeToString, + backend__pb2.Reply.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def LoadModel(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/backend.Backend/LoadModel', + backend__pb2.ModelOptions.SerializeToString, + backend__pb2.Result.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def PredictStream(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_stream(request, target, '/backend.Backend/PredictStream', + backend__pb2.PredictOptions.SerializeToString, + backend__pb2.Reply.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def Embedding(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/backend.Backend/Embedding', + backend__pb2.PredictOptions.SerializeToString, + backend__pb2.EmbeddingResult.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def GenerateImage(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/backend.Backend/GenerateImage', + backend__pb2.GenerateImageRequest.SerializeToString, + backend__pb2.Result.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def AudioTranscription(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/backend.Backend/AudioTranscription', + backend__pb2.TranscriptRequest.SerializeToString, + backend__pb2.TranscriptResult.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def TTS(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/backend.Backend/TTS', + backend__pb2.TTSRequest.SerializeToString, + backend__pb2.Result.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def TokenizeString(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/backend.Backend/TokenizeString', + backend__pb2.PredictOptions.SerializeToString, + backend__pb2.TokenizationResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def Status(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/backend.Backend/Status', + backend__pb2.HealthMessage.SerializeToString, + backend__pb2.StatusResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/backend/python/petals/backend_petals.py b/backend/python/petals/backend_petals.py new file mode 100755 index 00000000..73bcc4a0 --- /dev/null +++ b/backend/python/petals/backend_petals.py @@ -0,0 +1,140 @@ +#!/usr/bin/env python3 +from concurrent import futures +import time +import argparse +import signal +import sys +import os + +import backend_pb2 +import backend_pb2_grpc + +import grpc +import torch +from transformers import AutoTokenizer +from petals import AutoDistributedModelForCausalLM + +_ONE_DAY_IN_SECONDS = 60 * 60 * 24 + +# If MAX_WORKERS are specified in the environment use it, otherwise default to 1 +MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1')) + +# Implement the BackendServicer class with the service methods +class BackendServicer(backend_pb2_grpc.BackendServicer): + """ + A gRPC servicer that implements the Backend service defined in backend.proto. + """ + def Health(self, request, context): + """ + Returns a health check message. + + Args: + request: The health check request. + context: The gRPC context. + + Returns: + backend_pb2.Reply: The health check reply. + """ + return backend_pb2.Reply(message=bytes("OK", 'utf-8')) + + def LoadModel(self, request, context): + """ + Loads a language model. + + Args: + request: The load model request. + context: The gRPC context. + + Returns: + backend_pb2.Result: The load model result. + """ + try: + self.tokenizer = AutoTokenizer.from_pretrained(request.Model, use_fast=False, add_bos_token=False) + self.model = AutoDistributedModelForCausalLM.from_pretrained(request.Model) + self.cuda = False + if request.CUDA: + self.model = self.model.cuda() + self.cuda = True + + except Exception as err: + return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}") + return backend_pb2.Result(message="Model loaded successfully", success=True) + + def Predict(self, request, context): + """ + Generates text based on the given prompt and sampling parameters. + + Args: + request: The predict request. + context: The gRPC context. + + Returns: + backend_pb2.Result: The predict result. + """ + + inputs = self.tokenizer(request.Prompt, return_tensors="pt")["input_ids"] + if self.cuda: + inputs = inputs.cuda() + + if request.Tokens == 0: + # Max to max value if tokens are not specified + request.Tokens = 8192 + + # TODO: kwargs and map all parameters + outputs = self.model.generate(inputs, max_new_tokens=request.Tokens) + + generated_text = self.tokenizer.decode(outputs[0]) + # Remove prompt from response if present + if request.Prompt in generated_text: + generated_text = generated_text.replace(request.Prompt, "") + + return backend_pb2.Result(message=bytes(generated_text, encoding='utf-8')) + + def PredictStream(self, request, context): + """ + Generates text based on the given prompt and sampling parameters, and streams the results. + + Args: + request: The predict stream request. + context: The gRPC context. + + Returns: + backend_pb2.Result: The predict stream result. + """ + # Implement PredictStream RPC + #for reply in some_data_generator(): + # yield reply + # Not implemented yet + return self.Predict(request, context) + +def serve(address): + server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS)) + backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) + server.add_insecure_port(address) + server.start() + print("Server started. Listening on: " + address, file=sys.stderr) + + # Define the signal handler function + def signal_handler(sig, frame): + print("Received termination signal. Shutting down...") + server.stop(0) + sys.exit(0) + + # Set the signal handlers for SIGINT and SIGTERM + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) + + try: + while True: + time.sleep(_ONE_DAY_IN_SECONDS) + except KeyboardInterrupt: + server.stop(0) + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run the gRPC server.") + parser.add_argument( + "--addr", default="localhost:50051", help="The address to bind the server to." + ) + args = parser.parse_args() + + serve(args.addr) diff --git a/backend/python/petals/petals.yml b/backend/python/petals/petals.yml new file mode 100644 index 00000000..a68f39b4 --- /dev/null +++ b/backend/python/petals/petals.yml @@ -0,0 +1,29 @@ +name: petals +channels: + - defaults +dependencies: + # - _libgcc_mutex=0.1=main + # - _openmp_mutex=5.1=1_gnu + # - bzip2=1.0.8=h7b6447c_0 + # - ca-certificates=2023.08.22=h06a4308_0 + # - ld_impl_linux-64=2.38=h1181459_1 + # - libffi=3.4.4=h6a678d5_0 + # - libgcc-ng=11.2.0=h1234567_1 + # - libgomp=11.2.0=h1234567_1 + # - libstdcxx-ng=11.2.0=h1234567_1 + # - libuuid=1.41.5=h5eee18b_0 + # - ncurses=6.4=h6a678d5_0 + # - openssl=3.0.11=h7f8727e_2 + # - pip=23.2.1=py311h06a4308_0 + # - python=3.11.5=h955ad1f_0 + # - readline=8.2=h5eee18b_0 + # - setuptools=68.0.0=py311h06a4308_0 + # - sqlite=3.41.2=h5eee18b_0 + # - tk=8.6.12=h1ccaba5_0 + # - tzdata=2023c=h04d1e81_0 + # - wheel=0.41.2=py311h06a4308_0 + # - xz=5.4.2=h5eee18b_0 + # - zlib=1.2.13=h5eee18b_0 + - pip: + - git+https://github.com/bigscience-workshop/petals +prefix: /opt/conda/envs/petals diff --git a/backend/python/petals/run.sh b/backend/python/petals/run.sh new file mode 100755 index 00000000..64a1a66f --- /dev/null +++ b/backend/python/petals/run.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +## +## A bash script wrapper that runs the exllama server with conda + +export PATH=$PATH:/opt/conda/bin + +# Activate conda environment +# if source is available use it, or use conda +# +if [ -f /opt/conda/bin/activate ]; then + source activate petals +else + eval "$(conda shell.bash hook)" + conda activate petals +fi + +# get the directory where the bash script is located +DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" + +python $DIR/backend_petals.py $@