diff --git a/extra/grpc/vllm/backend_vllm.py b/extra/grpc/vllm/backend_vllm.py
index 4b884f6f..a35cbc74 100644
--- a/extra/grpc/vllm/backend_vllm.py
+++ b/extra/grpc/vllm/backend_vllm.py
@@ -49,11 +49,13 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         return backend_pb2.Result(message="Model loaded successfully", success=True)
 
     def Predict(self, request, context):
+        if request.TopP == 0:
+            request.TopP = 0.9
+
         sampling_params = SamplingParams(temperature=request.Temperature, top_p=request.TopP)
         outputs = self.llm.generate([request.Prompt], sampling_params)
         generated_text = outputs[0].outputs[0].text
-
         # Remove prompt from response if present
         if request.Prompt in generated_text:
             generated_text = generated_text.replace(request.Prompt, "")
 