diff --git a/backend/python/exllama2/exllama2_backend.py b/backend/python/exllama2/exllama2_backend.py
index af3e2970..cb21ed7e 100755
--- a/backend/python/exllama2/exllama2_backend.py
+++ b/backend/python/exllama2/exllama2_backend.py
@@ -7,7 +7,8 @@ import backend_pb2_grpc
 import argparse
 import signal
 import sys
-import os, glob
+import os
+import glob
 from pathlib import Path
 import torch
 
@@ -21,7 +22,7 @@ from exllamav2.generator import (
 )
 
 
-from exllamav2 import(
+from exllamav2 import (
     ExLlamaV2,
     ExLlamaV2Config,
     ExLlamaV2Cache,
@@ -40,6 +41,7 @@ MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
 class BackendServicer(backend_pb2_grpc.BackendServicer):
     def Health(self, request, context):
         return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
+
     def LoadModel(self, request, context):
         try:
             model_directory = request.ModelFile
@@ -50,7 +52,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
 
             model = ExLlamaV2(config)
 
-            cache = ExLlamaV2Cache(model, lazy = True)
+            cache = ExLlamaV2Cache(model, lazy=True)
             model.load_autosplit(cache)
 
             tokenizer = ExLlamaV2Tokenizer(config)
@@ -59,7 +61,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
 
             generator = ExLlamaV2BaseGenerator(model, cache, tokenizer)
 
-            self.generator= generator
+            self.generator = generator
 
             generator.warmup()
             self.model = model
@@ -85,17 +87,18 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         if request.Tokens != 0:
             tokens = request.Tokens
 
-        output = self.generator.generate_simple(request.Prompt, settings, tokens, seed = self.seed)
+        output = self.generator.generate_simple(
+            request.Prompt, settings, tokens)
 
         # Remove prompt from response if present
         if request.Prompt in output:
             output = output.replace(request.Prompt, "")
 
-        return backend_pb2.Result(message=bytes(t, encoding='utf-8'))
+        return backend_pb2.Result(message=bytes(output, encoding='utf-8'))
 
     def PredictStream(self, request, context):
         # Implement PredictStream RPC
-        #for reply in some_data_generator():
+        # for reply in some_data_generator():
         #    yield reply
         # Not implemented yet
         return self.Predict(request, context)
@@ -124,6 +127,7 @@ def serve(address):
     except KeyboardInterrupt:
         server.stop(0)
 
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Run the gRPC server.")
     parser.add_argument(
@@ -131,4 +135,4 @@ if __name__ == "__main__":
     )
     args = parser.parse_args()
 
-    serve(args.addr)
\ No newline at end of file
+    serve(args.addr)