diff --git a/backend/python/transformers/transformers_server.py b/backend/python/transformers/transformers_server.py
index 1b38a956..90053ed5 100755
--- a/backend/python/transformers/transformers_server.py
+++ b/backend/python/transformers/transformers_server.py
@@ -159,6 +159,11 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                                                              quantization_config=quantization,
                                                              device_map=device_map,
                                                              torch_dtype=compute)
+            if request.ContextSize > 0:
+                self.max_tokens = request.ContextSize
+            else:
+                self.max_tokens = self.model.config.max_position_embeddings
+
             self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_safetensors=True)
             self.XPU = False
 
@@ -217,10 +222,6 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         if request.TopK == 0:
             request.TopK = 40
 
-        max_tokens = 200
-        if request.Tokens > 0:
-            max_tokens = request.Tokens
-
         prompt = request.Prompt
         if not request.Prompt and request.UseTokenizerTemplate and request.Messages:
            prompt = self.tokenizer.apply_chat_template(request.Messages, tokenize=False, add_generation_prompt=True)
@@ -232,6 +233,12 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
             eos_token_id.append(self.tokenizer.convert_tokens_to_ids(word))
 
         inputs = self.tokenizer(prompt, return_tensors="pt")
+
+        if request.Tokens > 0:
+            max_tokens = request.Tokens
+        else:
+            max_tokens = self.max_tokens - inputs["input_ids"].size()[inputs["input_ids"].dim()-1]
+
         if self.CUDA:
             inputs = inputs.to("cuda")
         if XPU and self.OV == False:
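
The net effect of this patch is that, when the request does not specify a token count, the generation budget becomes "context window minus prompt length" instead of a hard-coded 200. The following standalone Python sketch illustrates that arithmetic outside the gRPC servicer; the `gpt2` model and the `prompt` string are placeholders used only for illustration, not part of the patch.

```python
# Illustrative sketch of the token-budget calculation introduced by the patch.
# Assumptions: `transformers` is installed and "gpt2" is used as a stand-in model.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # placeholder model
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Context window: an explicit ContextSize from the request would take priority;
# otherwise fall back to the model's own limit, as the patch does.
context_size = model.config.max_position_embeddings

prompt = "Hello, world"  # placeholder prompt
inputs = tokenizer(prompt, return_tensors="pt")

# Prompt length is the size of the last dimension of input_ids; the generation
# budget is whatever remains of the context window after the prompt.
prompt_len = inputs["input_ids"].size(-1)
max_new_tokens = context_size - prompt_len
print(f"prompt tokens: {prompt_len}, remaining budget: {max_new_tokens}")
```

This mirrors the diff's `self.max_tokens - inputs["input_ids"].size()[inputs["input_ids"].dim()-1]` expression, written here with `size(-1)` for brevity.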