feat: llama.cpp gRPC C++ backend (#1170)

* wip: llama.cpp c++ gRPC server

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* make it work, attach it to the build process

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* update deps

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* fix: add protobuf dep

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* try fix protobuf on cmake

* cmake: workarounds

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* add packages

* cmake: use fixed version of grpc

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* cmake(grpc): install locally

* install grpc

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* install required deps for grpc on debian bullseye

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* debug

* debug

* Fixups

* no need to install cmake manually

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* ci: fixup macOS

* use brew whenever possible

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* macOS fixups

* debug

* fix container build

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* workaround

* try mac

https://stackoverflow.com/questions/23905661/on-mac-g-clang-fails-to-search-usr-local-include-and-usr-local-lib-by-def

* Disable temp. arm64 docker image builds

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto 2023-10-16 21:46:29 +02:00 committed by GitHub
parent 8034ed3473
commit 128694213f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 1145 additions and 16 deletions

View File

@ -12,6 +12,9 @@ jobs:
- repository: "go-skynet/go-llama.cpp"
variable: "GOLLAMA_VERSION"
branch: "master"
- repository: "ggerganov/llama.cpp"
variable: "CPPLLAMA_VERSION"
branch: "master"
- repository: "go-skynet/go-ggml-transformers.cpp"
variable: "GOGGMLTRANSFORMERS_VERSION"
branch: "master"

View File

@ -19,7 +19,8 @@ jobs:
matrix:
include:
- build-type: ''
platforms: 'linux/amd64,linux/arm64'
#platforms: 'linux/amd64,linux/arm64'
platforms: 'linux/amd64'
tag-latest: 'auto'
tag-suffix: ''
ffmpeg: ''
@ -38,7 +39,7 @@ jobs:
tag-suffix: '-cublas-cuda12'
ffmpeg: ''
- build-type: ''
platforms: 'linux/amd64,linux/arm64'
platforms: 'linux/amd64'
tag-latest: 'false'
tag-suffix: '-ffmpeg'
ffmpeg: 'true'

View File

@ -29,6 +29,12 @@ jobs:
run: |
sudo apt-get update
sudo apt-get install build-essential ffmpeg
git clone --recurse-submodules -b v1.58.0 --depth 1 --shallow-submodules https://github.com/grpc/grpc && \
cd grpc && mkdir -p cmake/build && cd cmake/build && cmake -DgRPC_INSTALL=ON \
-DgRPC_BUILD_TESTS=OFF \
../.. && sudo make -j12 install
- name: Build
id: build
env:
@ -66,12 +72,20 @@ jobs:
- uses: actions/setup-go@v4
with:
go-version: '>=1.21.0'
- name: Dependencies
run: |
git clone --recurse-submodules -b v1.58.0 --depth 1 --shallow-submodules https://github.com/grpc/grpc && \
cd grpc && mkdir -p cmake/build && cd cmake/build && cmake -DgRPC_INSTALL=ON \
-DgRPC_BUILD_TESTS=OFF \
../.. && make -j12 install && rm -rf grpc
- name: Build
id: build
env:
CMAKE_ARGS: "${{ matrix.defines }}"
BUILD_ID: "${{ matrix.build }}"
run: |
export C_INCLUDE_PATH=/usr/local/include
export CPLUS_INCLUDE_PATH=/usr/local/include
make dist
- uses: actions/upload-artifact@v3
with:

View File

@ -67,11 +67,15 @@ jobs:
run: |
sudo apt-get update
sudo apt-get install build-essential ffmpeg
sudo apt-get install -y ca-certificates cmake curl patch
sudo apt-get install -y libopencv-dev && sudo ln -s /usr/include/opencv4/opencv2 /usr/include/opencv2
sudo pip install -r extra/requirements.txt
# Pre-build stable diffusion before we install a newever version of abseil (not compatible with stablediffusion-ncn)
GO_TAGS="tts stablediffusion" GRPC_BACKENDS=backend-assets/grpc/stablediffusion make build
sudo mkdir /build && sudo chmod -R 777 /build && cd /build && \
curl -L "https://github.com/gabime/spdlog/archive/refs/tags/v1.11.0.tar.gz" | \
tar -xzvf - && \
@ -87,6 +91,12 @@ jobs:
sudo cp -rfv /build/lib/Linux-$(uname -m)/piper_phonemize/lib/. /usr/lib/ && \
sudo ln -s /usr/lib/libpiper_phonemize.so /usr/lib/libpiper_phonemize.so.1 && \
sudo cp -rfv /build/lib/Linux-$(uname -m)/piper_phonemize/include/. /usr/include/
git clone --recurse-submodules -b v1.58.0 --depth 1 --shallow-submodules https://github.com/grpc/grpc && \
cd grpc && mkdir -p cmake/build && cd cmake/build && cmake -DgRPC_INSTALL=ON \
-DgRPC_BUILD_TESTS=OFF \
../.. && sudo make -j12 install
- name: Test
run: |
ESPEAK_DATA="/build/lib/Linux-$(uname -m)/piper_phonemize/lib/espeak-ng-data" GO_TAGS="tts stablediffusion" make test
@ -108,6 +118,14 @@ jobs:
# You can test your matrix by printing the current Go version
- name: Display Go version
run: go version
- name: Dependencies
run: |
git clone --recurse-submodules -b v1.58.0 --depth 1 --shallow-submodules https://github.com/grpc/grpc && \
cd grpc && mkdir -p cmake/build && cd cmake/build && cmake -DgRPC_INSTALL=ON \
-DgRPC_BUILD_TESTS=OFF \
../.. && make -j12 install && rm -rf grpc
- name: Test
run: |
export C_INCLUDE_PATH=/usr/local/include
export CPLUS_INCLUDE_PATH=/usr/local/include
CMAKE_ARGS="-DLLAMA_F16C=OFF -DLLAMA_AVX512=OFF -DLLAMA_AVX2=OFF -DLLAMA_FMA=OFF" make test

View File

@ -16,7 +16,8 @@ ENV GALLERIES='[{"name":"model-gallery", "url":"github:go-skynet/model-gallery/i
ARG GO_TAGS="stablediffusion tts"
RUN apt-get update && \
apt-get install -y ca-certificates cmake curl patch pip
apt-get install -y ca-certificates curl patch pip cmake
# Use the variables in subsequent instructions
RUN echo "Target Architecture: $TARGETARCH"
@ -104,6 +105,15 @@ RUN make prepare
COPY . .
COPY .git .
# stablediffusion does not tolerate a newer version of abseil, build it first
RUN GRPC_BACKENDS=backend-assets/grpc/stablediffusion make build
RUN git clone --recurse-submodules -b v1.58.0 --depth 1 --shallow-submodules https://github.com/grpc/grpc && \
cd grpc && mkdir -p cmake/build && cd cmake/build && cmake -DgRPC_INSTALL=ON \
-DgRPC_BUILD_TESTS=OFF \
../.. && make -j12 install && rm -rf grpc
# Rebuild with defaults backends
RUN ESPEAK_DATA=/build/lib/Linux-$(uname -m)/piper_phonemize/lib/espeak-ng-data make build
###################################
@ -132,8 +142,13 @@ WORKDIR /build
# https://github.com/go-skynet/LocalAI/pull/434
COPY . .
RUN make prepare-sources
# Copy the binary
COPY --from=builder /build/local-ai ./
# do not let piper rebuild (requires an older version of absl)
COPY --from=builder /build/backend-assets/grpc/piper ./backend-assets/grpc/piper
# Copy VALLE-X as it's not a real "lib"
RUN cp -rfv /usr/lib/vall-e-x/* ./

View File

@ -8,6 +8,8 @@ GOLLAMA_VERSION?=1676dcd7a139b6cdfbaea5fd67f46dc25d9d8bcf
GOLLAMA_STABLE_VERSION?=50cee7712066d9e38306eccadcfbb44ea87df4b7
CPPLLAMA_VERSION?=24ba3d829e31a6eda3fa1723f692608c2fa3adda
# gpt4all version
GPT4ALL_REPO?=https://github.com/nomic-ai/gpt4all
GPT4ALL_VERSION?=27a8b020c36b0df8f8b82a252d261cda47cf44b8
@ -120,7 +122,7 @@ ifeq ($(findstring tts,$(GO_TAGS)),tts)
OPTIONAL_GRPC+=backend-assets/grpc/piper
endif
GRPC_BACKENDS?=backend-assets/grpc/langchain-huggingface backend-assets/grpc/falcon-ggml backend-assets/grpc/bert-embeddings backend-assets/grpc/falcon backend-assets/grpc/bloomz backend-assets/grpc/llama backend-assets/grpc/llama-stable backend-assets/grpc/gpt4all backend-assets/grpc/dolly backend-assets/grpc/gpt2 backend-assets/grpc/gptj backend-assets/grpc/gptneox backend-assets/grpc/mpt backend-assets/grpc/replit backend-assets/grpc/starcoder backend-assets/grpc/rwkv backend-assets/grpc/whisper $(OPTIONAL_GRPC)
GRPC_BACKENDS?=backend-assets/grpc/langchain-huggingface backend-assets/grpc/falcon-ggml backend-assets/grpc/bert-embeddings backend-assets/grpc/falcon backend-assets/grpc/bloomz backend-assets/grpc/llama backend-assets/grpc/llama-cpp backend-assets/grpc/llama-stable backend-assets/grpc/gpt4all backend-assets/grpc/dolly backend-assets/grpc/gpt2 backend-assets/grpc/gptj backend-assets/grpc/gptneox backend-assets/grpc/mpt backend-assets/grpc/replit backend-assets/grpc/starcoder backend-assets/grpc/rwkv backend-assets/grpc/whisper $(OPTIONAL_GRPC)
.PHONY: all test build vendor
@ -223,7 +225,7 @@ go-llama/libbinding.a: go-llama
go-llama-stable/libbinding.a: go-llama-stable
$(MAKE) -C go-llama-stable BUILD_TYPE=$(STABLE_BUILD_TYPE) libbinding.a
go-piper/libpiper_binding.a:
go-piper/libpiper_binding.a: go-piper
$(MAKE) -C go-piper libpiper_binding.a example/main
get-sources: go-llama go-llama-stable go-ggllm go-ggml-transformers gpt4all go-piper go-rwkv whisper.cpp go-bert bloomz go-stable-diffusion
@ -280,6 +282,7 @@ clean: ## Remove build related file
rm -rf ./go-ggllm
rm -rf $(BINARY_NAME)
rm -rf release/
$(MAKE) -C backend/cpp/llama clean
## Build:
@ -395,6 +398,16 @@ ifeq ($(BUILD_TYPE),metal)
cp go-llama/build/bin/ggml-metal.metal backend-assets/grpc/
endif
backend/cpp/llama/grpc-server:
LLAMA_VERSION=$(CPPLLAMA_VERSION) $(MAKE) -C backend/cpp/llama grpc-server
backend-assets/grpc/llama-cpp: backend-assets/grpc backend/cpp/llama/grpc-server
cp -rfv backend/cpp/llama/grpc-server backend-assets/grpc/llama-cpp
# TODO: every binary should have its own folder instead, so can have different metal implementations
ifeq ($(BUILD_TYPE),metal)
cp backend/cpp/llama/llama.cpp/build/bin/ggml-metal.metal backend-assets/grpc/
endif
backend-assets/grpc/llama-stable: backend-assets/grpc go-llama-stable/libbinding.a
$(GOCMD) mod edit -replace github.com/go-skynet/go-llama.cpp=$(shell pwd)/go-llama-stable
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-llama-stable LIBRARY_PATH=$(shell pwd)/go-llama \
@ -451,9 +464,12 @@ backend-assets/grpc/bert-embeddings: backend-assets/grpc go-bert/libgobert.a
backend-assets/grpc/langchain-huggingface: backend-assets/grpc
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/langchain-huggingface ./cmd/grpc/langchain-huggingface/
backend-assets/grpc/stablediffusion: backend-assets/grpc go-stable-diffusion/libstablediffusion.a
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-stable-diffusion/ LIBRARY_PATH=$(shell pwd)/go-stable-diffusion/ \
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/stablediffusion ./cmd/grpc/stablediffusion/
backend-assets/grpc/stablediffusion: backend-assets/grpc
if [ ! -f backend-assets/grpc/stablediffusion ]; then \
$(MAKE) go-stable-diffusion/libstablediffusion.a; \
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-stable-diffusion/ LIBRARY_PATH=$(shell pwd)/go-stable-diffusion/ \
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/stablediffusion ./cmd/grpc/stablediffusion/; \
fi
backend-assets/grpc/piper: backend-assets/grpc backend-assets/espeak-ng-data go-piper/libpiper_binding.a
CGO_LDFLAGS="$(CGO_LDFLAGS)" LIBRARY_PATH=$(shell pwd)/go-piper \

View File

@ -0,0 +1,57 @@
set(CMAKE_CXX_STANDARD 17)
cmake_minimum_required(VERSION 3.15)
set(TARGET grpc-server)
set(_PROTOBUF_LIBPROTOBUF libprotobuf)
set(_REFLECTION grpc++_reflection)
find_package(absl CONFIG REQUIRED)
find_package(Protobuf CONFIG REQUIRED)
find_package(gRPC CONFIG REQUIRED)
find_program(_PROTOBUF_PROTOC protoc)
set(_GRPC_GRPCPP grpc++)
find_program(_GRPC_CPP_PLUGIN_EXECUTABLE grpc_cpp_plugin)
include_directories(${CMAKE_CURRENT_BINARY_DIR})
include_directories(${Protobuf_INCLUDE_DIRS})
message(STATUS "Using protobuf ${Protobuf_VERSION} ${Protobuf_INCLUDE_DIRS} ${CMAKE_CURRENT_BINARY_DIR}")
# Proto file
get_filename_component(hw_proto "../../../../../../pkg/grpc/proto/backend.proto" ABSOLUTE)
get_filename_component(hw_proto_path "${hw_proto}" PATH)
# Generated sources
set(hw_proto_srcs "${CMAKE_CURRENT_BINARY_DIR}/backend.pb.cc")
set(hw_proto_hdrs "${CMAKE_CURRENT_BINARY_DIR}/backend.pb.h")
set(hw_grpc_srcs "${CMAKE_CURRENT_BINARY_DIR}/backend.grpc.pb.cc")
set(hw_grpc_hdrs "${CMAKE_CURRENT_BINARY_DIR}/backend.grpc.pb.h")
add_custom_command(
OUTPUT "${hw_proto_srcs}" "${hw_proto_hdrs}" "${hw_grpc_srcs}" "${hw_grpc_hdrs}"
COMMAND ${_PROTOBUF_PROTOC}
ARGS --grpc_out "${CMAKE_CURRENT_BINARY_DIR}"
--cpp_out "${CMAKE_CURRENT_BINARY_DIR}"
-I "${hw_proto_path}"
--plugin=protoc-gen-grpc="${_GRPC_CPP_PLUGIN_EXECUTABLE}"
"${hw_proto}"
DEPENDS "${hw_proto}")
# hw_grpc_proto
add_library(hw_grpc_proto
${hw_grpc_srcs}
${hw_grpc_hdrs}
${hw_proto_srcs}
${hw_proto_hdrs})
add_executable(${TARGET} grpc-server.cpp)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT} absl::flags hw_grpc_proto
absl::flags_parse
gRPC::${_REFLECTION}
gRPC::${_GRPC_GRPCPP}
protobuf::${_PROTOBUF_LIBPROTOBUF})
target_compile_features(${TARGET} PRIVATE cxx_std_11)
if(TARGET BUILD_INFO)
add_dependencies(${TARGET} BUILD_INFO)
endif()

View File

@ -0,0 +1,44 @@
LLAMA_VERSION?=24ba3d829e31a6eda3fa1723f692608c2fa3adda
CMAKE_ARGS?=
BUILD_TYPE?=
# If build type is cublas, then we set -DLLAMA_CUBLAS=ON to CMAKE_ARGS automatically
ifeq ($(BUILD_TYPE),cublas)
CMAKE_ARGS+=-DLLAMA_CUBLAS=ON
# If build type is openblas then we set -DLLAMA_BLAS=ON -DLLAMA_BLAS_VENDOR=OpenBLAS
# to CMAKE_ARGS automatically
else ifeq ($(BUILD_TYPE),openblas)
CMAKE_ARGS+=-DLLAMA_BLAS=ON -DLLAMA_BLAS_VENDOR=OpenBLAS
# If build type is clblast (openCL) we set -DLLAMA_CLBLAST=ON -DCLBlast_DIR=/some/path
else ifeq ($(BUILD_TYPE),clblast)
CMAKE_ARGS+=-DLLAMA_CLBLAST=ON -DCLBlast_DIR=/some/path
# If it's hipblas we do have also to set CC=/opt/rocm/llvm/bin/clang CXX=/opt/rocm/llvm/bin/clang++
else ifeq ($(BUILD_TYPE),hipblas)
CMAKE_ARGS+=-DLLAMA_HIPBLAS=ON
endif
llama.cpp:
git clone --recurse-submodules https://github.com/ggerganov/llama.cpp llama.cpp
cd llama.cpp && git checkout -b build $(LLAMA_VERSION) && git submodule update --init --recursive --depth 1
llama.cpp/examples/grpc-server:
mkdir -p llama.cpp/examples/grpc-server
cp -r $(abspath ./)/CMakeLists.txt llama.cpp/examples/grpc-server/
cp -r $(abspath ./)/grpc-server.cpp llama.cpp/examples/grpc-server/
echo "add_subdirectory(grpc-server)" >> llama.cpp/examples/CMakeLists.txt
rebuild:
cp -rfv $(abspath ./)/CMakeLists.txt llama.cpp/examples/grpc-server/
cp -rfv $(abspath ./)/grpc-server.cpp llama.cpp/examples/grpc-server/
rm -rf grpc-server
$(MAKE) grpc-server
clean:
rm -rf llama.cpp
rm -rf grpc-server
grpc-server: llama.cpp llama.cpp/examples/grpc-server
cd llama.cpp && mkdir -p build && cd build && cmake .. $(CMAKE_ARGS) && cmake --build . --config Release
cp llama.cpp/build/bin/grpc-server .

View File

@ -0,0 +1,964 @@
// llama.cpp gRPC C++ backend server
//
// Ettore Di Giacinto <mudler@localai.io>
//
// This is a gRPC server for llama.cpp compatible with the LocalAI proto
// Note: this is a re-adaptation of the original llama.cpp example/server.cpp for HTTP,
// but modified to work with gRPC
//
#include <iostream>
#include <memory>
#include <string>
#include <getopt.h>
#include "common.h"
#include "llama.h"
#include "grammar-parser.h"
#include "backend.pb.h"
#include "backend.grpc.pb.h"
// include std::regex
#include <regex>
#include <grpcpp/ext/proto_server_reflection_plugin.h>
#include <grpcpp/grpcpp.h>
#include <grpcpp/health_check_service_interface.h>
using grpc::Server;
using grpc::ServerBuilder;
using grpc::ServerContext;
using grpc::Status;
using backend::HealthMessage;
// completion token output with probabilities
struct completion_token_output
{
struct token_prob
{
llama_token tok;
float prob;
};
std::vector<token_prob> probs;
llama_token tok;
};
static size_t common_part(const std::vector<llama_token> &a, const std::vector<llama_token> &b)
{
size_t i;
for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++)
{
}
return i;
}
enum stop_type
{
STOP_FULL,
STOP_PARTIAL,
};
static bool ends_with(const std::string &str, const std::string &suffix)
{
return str.size() >= suffix.size() &&
0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
}
static size_t find_partial_stop_string(const std::string &stop,
const std::string &text)
{
if (!text.empty() && !stop.empty())
{
const char text_last_char = text.back();
for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--)
{
if (stop[char_index] == text_last_char)
{
const std::string current_partial = stop.substr(0, char_index + 1);
if (ends_with(text, current_partial))
{
return text.size() - char_index - 1;
}
}
}
}
return std::string::npos;
}
template <class Iter>
static std::string tokens_to_str(llama_context *ctx, Iter begin, Iter end)
{
std::string ret;
for (; begin != end; ++begin)
{
ret += llama_token_to_piece(ctx, *begin);
}
return ret;
}
// format incomplete utf-8 multibyte character for output
static std::string tokens_to_output_formatted_string(const llama_context *ctx, const llama_token token)
{
std::string out = token == -1 ? "" : llama_token_to_piece(ctx, token);
// if the size is 1 and first bit is 1, meaning it's a partial character
// (size > 1 meaning it's already a known token)
if (out.size() == 1 && (out[0] & 0x80) == 0x80)
{
std::stringstream ss;
ss << std::hex << (out[0] & 0xff);
std::string res(ss.str());
out = "byte: \\x" + res;
}
return out;
}
struct llama_server_context
{
bool stream = false;
bool has_next_token = false;
std::string generated_text;
std::vector<completion_token_output> generated_token_probs;
size_t num_prompt_tokens = 0;
size_t num_tokens_predicted = 0;
size_t n_past = 0;
size_t n_remain = 0;
// json prompt;
std::vector<llama_token> embd;
std::vector<llama_token> last_n_tokens;
llama_model *model = nullptr;
llama_context *ctx = nullptr;
gpt_params params;
int n_ctx;
grammar_parser::parse_state parsed_grammar;
llama_grammar *grammar = nullptr;
bool truncated = false;
bool stopped_eos = false;
bool stopped_word = false;
bool stopped_limit = false;
std::string stopping_word;
int32_t multibyte_pending = 0;
std::mutex mutex;
std::unique_lock<std::mutex> lock()
{
return std::unique_lock<std::mutex>(mutex);
}
~llama_server_context()
{
if (ctx)
{
llama_free(ctx);
ctx = nullptr;
}
if (model)
{
llama_free_model(model);
model = nullptr;
}
}
void rewind()
{
params.antiprompt.clear();
params.grammar.clear();
num_prompt_tokens = 0;
num_tokens_predicted = 0;
generated_text = "";
generated_text.reserve(n_ctx);
generated_token_probs.clear();
truncated = false;
stopped_eos = false;
stopped_word = false;
stopped_limit = false;
stopping_word = "";
multibyte_pending = 0;
n_remain = 0;
n_past = 0;
if (grammar != nullptr) {
llama_grammar_free(grammar);
grammar = nullptr;
}
}
bool loadModel(const gpt_params &params_)
{
printf("load model %s\n", params_.model.c_str());
params = params_;
std::tie(model, ctx) = llama_init_from_gpt_params(params);
if (model == nullptr)
{
printf("unable to load model %s\n", params_.model.c_str());
return false;
}
n_ctx = llama_n_ctx(ctx);
last_n_tokens.resize(n_ctx);
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
return true;
}
std::vector<llama_token> tokenize_array(const char **prompts, bool add_bos) const
{
// If `add_bos` is true, we only add BOS, when json_prompt is a string,
// or the first element of the json_prompt array is a string.
std::vector<llama_token> prompt_tokens;
bool first = true;
// Iterate over prompts
for (const char **p = prompts; *p != nullptr; ++p)
{
auto s = std::string(*p);
std::vector<llama_token> pp;
if (first)
{
pp = ::llama_tokenize(ctx, s, add_bos);
first = false;
}
else
{
pp = ::llama_tokenize(ctx, s, false);
}
prompt_tokens.insert(prompt_tokens.end(), pp.begin(), pp.end());
}
return prompt_tokens;
}
std::vector<llama_token> tokenize_string(const char *prompt, bool add_bos) const
{
// If `add_bos` is true, we only add BOS, when json_prompt is a string,
// or the first element of the json_prompt array is a string.
std::vector<llama_token> prompt_tokens;
auto s = std::string(prompt);
prompt_tokens = ::llama_tokenize(ctx, s, add_bos);
return prompt_tokens;
}
bool loadGrammar()
{
if (!params.grammar.empty()) {
parsed_grammar = grammar_parser::parse(params.grammar.c_str());
// will be empty (default) if there are parse errors
if (parsed_grammar.rules.empty()) {
printf("grammar parse error");
return false;
}
grammar_parser::print_grammar(stderr, parsed_grammar);
{
auto it = params.logit_bias.find(llama_token_eos(ctx));
if (it != params.logit_bias.end() && it->second == -INFINITY) {
printf("EOS token is disabled, which will cause most grammars to fail");
}
}
std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
grammar = llama_grammar_init(
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
}
return true;
}
void loadInfill()
{
bool suff_rm_leading_spc = true;
if (params.input_suffix.find_first_of(" ") == 0 && params.input_suffix.size() > 1) {
params.input_suffix.erase(0, 1);
suff_rm_leading_spc = false;
}
auto prefix_tokens = tokenize_string(params.input_prefix.c_str(), false);
auto suffix_tokens = tokenize_string(params.input_suffix.c_str(), false);
const int space_token = 29871;
if (suff_rm_leading_spc && suffix_tokens[0] == space_token) {
suffix_tokens.erase(suffix_tokens.begin());
}
prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(ctx));
prefix_tokens.insert(prefix_tokens.begin(), llama_token_bos(ctx)); // always add BOS
prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(ctx));
prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
prefix_tokens.push_back(llama_token_middle(ctx));
auto prompt_tokens = prefix_tokens;
num_prompt_tokens = prompt_tokens.size();
if (params.n_keep < 0)
{
params.n_keep = (int)num_prompt_tokens;
}
params.n_keep = std::min(params.n_ctx - 4, params.n_keep);
// if input prompt is too big, truncate like normal
if (num_prompt_tokens >= (size_t)params.n_ctx)
{
printf("Input prompt is too big, truncating. Can only take %d tokens but got %zu\n", params.n_ctx, num_prompt_tokens);
// todo we probably want to cut from both sides
const int n_left = (params.n_ctx - params.n_keep) / 2;
std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
std::copy(prompt_tokens.end() - params.n_ctx, prompt_tokens.end(), last_n_tokens.begin());
truncated = true;
prompt_tokens = new_tokens;
}
else
{
const size_t ps = num_prompt_tokens;
std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0);
std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps);
}
// compare the evaluated prompt with the new prompt
n_past = common_part(embd, prompt_tokens);
embd = prompt_tokens;
if (n_past == num_prompt_tokens)
{
// we have to evaluate at least 1 token to generate logits.
printf("we have to evaluate at least 1 token to generate logits\n");
n_past--;
}
llama_kv_cache_seq_rm(ctx, 0, n_past, -1);
has_next_token = true;
}
void loadPrompt(std::string prompt)
{
auto prompt_tokens = tokenize_string(prompt.c_str(), true); // always add BOS
num_prompt_tokens = prompt_tokens.size();
if (params.n_keep < 0)
{
params.n_keep = (int)num_prompt_tokens;
}
params.n_keep = std::min(n_ctx - 4, params.n_keep);
// if input prompt is too big, truncate like normal
if (num_prompt_tokens >= (size_t)n_ctx)
{
const int n_left = (n_ctx - params.n_keep) / 2;
std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
std::copy(prompt_tokens.end() - n_ctx, prompt_tokens.end(), last_n_tokens.begin());
truncated = true;
prompt_tokens = new_tokens;
}
else
{
const size_t ps = num_prompt_tokens;
std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0);
std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps);
}
// compare the evaluated prompt with the new prompt
n_past = common_part(embd, prompt_tokens);
embd = prompt_tokens;
if (n_past == num_prompt_tokens)
{
// we have to evaluate at least 1 token to generate logits.
n_past--;
}
// since #3228 we now have to manually manage the KV cache
llama_kv_cache_seq_rm(ctx, 0, n_past, -1);
has_next_token = true;
}
void beginCompletion()
{
// number of tokens to keep when resetting context
n_remain = params.n_predict;
llama_set_rng_seed(ctx, params.seed);
}
completion_token_output nextToken()
{
completion_token_output result;
result.tok = -1;
if (embd.size() >= (size_t)n_ctx)
{
// Shift context
const int n_left = n_past - params.n_keep - 1;
const int n_discard = n_left/2;
llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1);
llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
for (size_t i = params.n_keep + 1 + n_discard; i < embd.size(); i++)
{
embd[i - n_discard] = embd[i];
}
embd.resize(embd.size() - n_discard);
n_past -= n_discard;
truncated = true;
}
bool tg = true;
while (n_past < embd.size())
{
int n_eval = (int)embd.size() - n_past;
tg = n_eval == 1;
if (n_eval > params.n_batch)
{
n_eval = params.n_batch;
}
if (llama_decode(ctx, llama_batch_get_one(&embd[n_past], n_eval, n_past, 0)))
{
has_next_token = false;
return result;
}
n_past += n_eval;
}
if (params.n_predict == 0)
{
has_next_token = false;
result.tok = llama_token_eos(ctx);
return result;
}
{
// out of user input, sample next token
std::vector<llama_token_data> candidates;
candidates.reserve(llama_n_vocab(model));
result.tok = llama_sample_token(ctx, NULL, grammar, params, last_n_tokens, candidates);
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
const int32_t n_probs = params.n_probs;
if (params.temp <= 0 && n_probs > 0)
{
// For llama_sample_token_greedy we need to sort candidates
llama_sample_softmax(ctx, &candidates_p);
}
for (size_t i = 0; i < std::min(candidates_p.size, (size_t)n_probs); ++i)
{
result.probs.push_back({candidates_p.data[i].id, candidates_p.data[i].p});
}
last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(result.tok);
if (tg) {
num_tokens_predicted++;
}
}
// add it to the context
embd.push_back(result.tok);
// decrement remaining sampling budget
--n_remain;
if (!embd.empty() && embd.back() == llama_token_eos(ctx))
{
// stopping_word = llama_token_to_piece(ctx, embd.back());
has_next_token = false;
stopped_eos = true;
return result;
}
has_next_token = params.n_predict == -1 || n_remain != 0;
return result;
}
size_t findStoppingStrings(const std::string &text, const size_t last_token_size,
const stop_type type)
{
size_t stop_pos = std::string::npos;
for (const std::string &word : params.antiprompt)
{
size_t pos;
if (type == STOP_FULL)
{
const size_t tmp = word.size() + last_token_size;
const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0;
pos = text.find(word, from_pos);
}
else
{
pos = find_partial_stop_string(word, text);
}
if (pos != std::string::npos &&
(stop_pos == std::string::npos || pos < stop_pos))
{
if (type == STOP_FULL)
{
stopping_word = word;
stopped_word = true;
has_next_token = false;
}
stop_pos = pos;
}
}
return stop_pos;
}
completion_token_output doCompletion()
{
auto token_with_probs = nextToken();
const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_piece(ctx, token_with_probs.tok);
generated_text += token_text;
if (params.n_probs > 0)
{
generated_token_probs.push_back(token_with_probs);
}
if (multibyte_pending > 0)
{
multibyte_pending -= token_text.size();
}
else if (token_text.size() == 1)
{
const char c = token_text[0];
// 2-byte characters: 110xxxxx 10xxxxxx
if ((c & 0xE0) == 0xC0)
{
multibyte_pending = 1;
// 3-byte characters: 1110xxxx 10xxxxxx 10xxxxxx
}
else if ((c & 0xF0) == 0xE0)
{
multibyte_pending = 2;
// 4-byte characters: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
}
else if ((c & 0xF8) == 0xF0)
{
multibyte_pending = 3;
}
else
{
multibyte_pending = 0;
}
}
if (multibyte_pending > 0 && !has_next_token)
{
has_next_token = true;
n_remain++;
}
if (!has_next_token && n_remain == 0)
{
stopped_limit = true;
}
return token_with_probs;
}
std::vector<float> getEmbedding()
{
static const int n_embd = llama_n_embd(model);
if (!params.embedding)
{
printf("embedding disabled");
return std::vector<float>(n_embd, 0.0f);
}
const float *data = llama_get_embeddings(ctx);
std::vector<float> embedding(data, data + n_embd);
return embedding;
}
};
static void parse_options_completion(bool streaming,const backend::PredictOptions* predict, llama_server_context &llama)
{
gpt_params default_params;
llama.stream = streaming;
llama.params.n_predict = predict->tokens() == 0 ? -1 : predict->tokens();
llama.params.top_k = predict->topk();
llama.params.top_p = predict->topp();
llama.params.tfs_z = predict->tailfreesamplingz();
llama.params.typical_p = predict->typicalp();
llama.params.repeat_last_n = predict->repeat();
llama.params.temp = predict->temperature();
llama.params.repeat_penalty = predict->penalty();
llama.params.presence_penalty = predict->presencepenalty();
llama.params.frequency_penalty = predict->frequencypenalty();
llama.params.mirostat = predict->mirostat();
llama.params.mirostat_tau = predict->mirostattau();
llama.params.mirostat_eta = predict->mirostateta();
llama.params.penalize_nl = predict->penalizenl();
llama.params.n_keep = predict->nkeep();
llama.params.seed = predict->seed();
llama.params.grammar = predict->grammar();
// llama.params.n_probs = predict->
llama.params.prompt = predict->prompt();
llama.params.logit_bias.clear();
if (predict->ignoreeos())
{
llama.params.logit_bias[llama_token_eos(llama.ctx)] = -INFINITY;
}
// const auto &logit_bias = body.find("logit_bias");
// if (logit_bias != body.end() && logit_bias->is_array())
// {
// const int n_vocab = llama_n_vocab(llama.model);
// for (const auto &el : *logit_bias)
// {
// if (el.is_array() && el.size() == 2 && el[0].is_number_integer())
// {
// llama_token tok = el[0].get<llama_token>();
// if (tok >= 0 && tok < n_vocab)
// {
// if (el[1].is_number())
// {
// llama.params.logit_bias[tok] = el[1].get<float>();
// }
// else if (el[1].is_boolean() && !el[1].get<bool>())
// {
// llama.params.logit_bias[tok] = -INFINITY;
// }
// }
// }
// }
// }
llama.params.antiprompt.clear();
for (const std::string& stopPrompt : predict->stopprompts()) {
if (!stopPrompt.empty())
{
llama.params.antiprompt.push_back(stopPrompt);
}
}
}
static void params_parse(const backend::ModelOptions* request,
gpt_params & params) {
params.model = request->modelfile();
// params.model_alias ??
params.model_alias = request->modelfile();
params.n_ctx = request->contextsize();
params.memory_f16 = request->f16memory();
params.n_threads = request->threads();
params.n_gpu_layers = request->ngpulayers();
params.n_batch = request->nbatch();
if (!request->tensorsplit().empty()) {
std::string arg_next = request->tensorsplit();
// split string by , and /
const std::regex regex{ R"([,/]+)" };
std::sregex_token_iterator it{ arg_next.begin(), arg_next.end(), regex, -1 };
std::vector<std::string> split_arg{ it, {} };
GGML_ASSERT(split_arg.size() <= LLAMA_MAX_DEVICES);
for (size_t i_device = 0; i_device < LLAMA_MAX_DEVICES; ++i_device) {
if (i_device < split_arg.size()) {
params.tensor_split[i_device] = std::stof(split_arg[i_device]);
}
else {
params.tensor_split[i_device] = 0.0f;
}
}
}
if (!request->maingpu().empty()) {
params.main_gpu = std::stoi(request->maingpu());
}
// TODO: lora needs also a scale factor
//params.lora_adapter = request->loraadapter();
//params.lora_base = request->lorabase();
params.use_mlock = request->mlock();
params.use_mmap = request->mmap();
params.embedding = request->embeddings();
}
static bool is_at_eob(llama_server_context &server_context, const llama_token *tokens, const size_t n_tokens) {
return n_tokens && tokens[n_tokens-1] == llama_token_eos(server_context.ctx);
}
// Function matching type llama_beam_search_callback_fn_t.
// Custom callback example is called each time the beams lengths increase:
// * Show progress by printing ',' following by number of convergent beam tokens if any.
// * When all beams converge to a common prefix, they are made available in beams_state.beams[0].
// This is also called when the stop condition is met.
// Collect tokens into std::vector<llama_token> response which is pointed to by callback_data.
static void beam_search_callback(void *callback_data, llama_beams_state beams_state) {
auto & llama = *static_cast<llama_server_context*>(callback_data);
// Mark beams as EOS as needed.
for (size_t i = 0 ; i < beams_state.n_beams ; ++i) {
llama_beam_view& beam_view = beams_state.beam_views[i];
if (!beam_view.eob && is_at_eob(llama, beam_view.tokens, beam_view.n_tokens)) {
beam_view.eob = true;
}
}
printf(","); // Show progress
if (const size_t n = beams_state.common_prefix_length) {
llama.generated_token_probs.resize(llama.generated_token_probs.size() + n);
assert(0u < beams_state.n_beams);
const llama_token * tokens = beams_state.beam_views[0].tokens;
const auto map = [](llama_token tok) { return completion_token_output{{},tok}; };
std::transform(tokens, tokens + n, llama.generated_token_probs.end() - n, map);
printf("%zu", n);
}
fflush(stdout);
#if 0 // DEBUG: print current beams for this iteration
std::cout << "\n\nCurrent beams:\n";
for (size_t i=0 ; i < beams_state.n_beams ; ++i) {
std::cout << "beams["<<i<<"]: " << ostream_beam_view{state.ctx,beams_state.beam_views[i]} << std::endl;
}
#endif
}
struct token_translator {
llama_context * ctx;
std::string operator()(llama_token tok) const { return llama_token_to_piece(ctx, tok); }
std::string operator()(const completion_token_output & cto) const { return (*this)(cto.tok); }
};
static void append_to_generated_text_from_generated_token_probs(llama_server_context &llama)
{
auto & gtps = llama.generated_token_probs;
auto translator = token_translator{llama.ctx};
auto add_strlen = [=](size_t sum, const completion_token_output & cto) { return sum + translator(cto).size(); };
const size_t len = std::accumulate(gtps.begin(), gtps.end(), size_t(0), add_strlen);
if (llama.generated_text.capacity() < llama.generated_text.size() + len) {
llama.generated_text.reserve(llama.generated_text.size() + len);
}
for (const completion_token_output & cto : gtps) {
llama.generated_text += translator(cto);
}
}
// GRPC Server start
class BackendServiceImpl final : public backend::Backend::Service {
// The class has a llama instance that is shared across all RPCs
llama_server_context llama;
public:
grpc::Status Health(ServerContext* context, const backend::HealthMessage* request, backend::Reply* reply) {
// Implement Health RPC
reply->set_message("OK");
return Status::OK;
}
grpc::Status LoadModel(ServerContext* context, const backend::ModelOptions* request, backend::Result* result) {
// Implement LoadModel RPC
gpt_params params;
params_parse(request, params);
llama_backend_init(params.numa);
// load the model
if (!llama.loadModel(params))
{
result->set_message("Failed loading model");
result->set_success(false);
return Status::CANCELLED;
}
result->set_message("Loading succeeded");
result->set_success(true);
return Status::OK;
}
grpc::Status PredictStream(grpc::ServerContext* context, const backend::PredictOptions* request, grpc::ServerWriter<backend::Reply>* writer) override {
// Implement the streaming logic here based on the request options
// You can use writer->Write(response) to send a reply to the client
// and return grpc::Status::OK when the operation is complete.
auto lock = llama.lock();
llama.rewind();
llama_reset_timings(llama.ctx);
parse_options_completion(false, request, llama);
if (!llama.loadGrammar())
{
//res.status = 400;
return Status::CANCELLED;
}
llama.loadPrompt(request->prompt());
llama.beginCompletion();
size_t sent_count = 0;
size_t sent_token_probs_index = 0;
while (llama.has_next_token) {
const completion_token_output token_with_probs = llama.doCompletion();
if (token_with_probs.tok == -1 || llama.multibyte_pending > 0) {
continue;
}
const std::string token_text = llama_token_to_piece(llama.ctx, token_with_probs.tok);
size_t pos = std::min(sent_count, llama.generated_text.size());
const std::string str_test = llama.generated_text.substr(pos);
bool is_stop_full = false;
size_t stop_pos =
llama.findStoppingStrings(str_test, token_text.size(), STOP_FULL);
if (stop_pos != std::string::npos) {
is_stop_full = true;
llama.generated_text.erase(
llama.generated_text.begin() + pos + stop_pos,
llama.generated_text.end());
pos = std::min(sent_count, llama.generated_text.size());
} else {
is_stop_full = false;
stop_pos = llama.findStoppingStrings(str_test, token_text.size(),
STOP_PARTIAL);
}
if (
stop_pos == std::string::npos ||
// Send rest of the text if we are at the end of the generation
(!llama.has_next_token && !is_stop_full && stop_pos > 0)
) {
const std::string to_send = llama.generated_text.substr(pos, std::string::npos);
sent_count += to_send.size();
std::vector<completion_token_output> probs_output = {};
if (llama.params.n_probs > 0) {
const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, to_send, false);
size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size());
size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size());
if (probs_pos < probs_stop_pos) {
probs_output = std::vector<completion_token_output>(llama.generated_token_probs.begin() + probs_pos, llama.generated_token_probs.begin() + probs_stop_pos);
}
sent_token_probs_index = probs_stop_pos;
}
backend::Reply reply;
reply.set_message(to_send);
// Send the reply
writer->Write(reply);
}
}
llama_print_timings(llama.ctx);
llama.mutex.unlock();
lock.release();
return grpc::Status::OK;
}
grpc::Status Predict(ServerContext* context, const backend::PredictOptions* request, backend::Reply* reply) {
auto lock = llama.lock();
llama.rewind();
llama_reset_timings(llama.ctx);
parse_options_completion(false, request, llama);
if (!llama.loadGrammar())
{
//res.status = 400;
return Status::CANCELLED;
}
llama.loadPrompt(request->prompt());
llama.beginCompletion();
if (llama.params.n_beams) {
// Fill llama.generated_token_probs vector with final beam.
llama_beam_search(llama.ctx, beam_search_callback, &llama, llama.params.n_beams,
llama.n_past, llama.n_remain);
// Translate llama.generated_token_probs to llama.generated_text.
append_to_generated_text_from_generated_token_probs(llama);
} else {
size_t stop_pos = std::string::npos;
while (llama.has_next_token) {
const completion_token_output token_with_probs = llama.doCompletion();
const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_piece(llama.ctx, token_with_probs.tok);
stop_pos = llama.findStoppingStrings(llama.generated_text,
token_text.size(), STOP_FULL);
}
if (stop_pos == std::string::npos) {
stop_pos = llama.findStoppingStrings(llama.generated_text, 0, STOP_PARTIAL);
}
if (stop_pos != std::string::npos) {
llama.generated_text.erase(llama.generated_text.begin() + stop_pos,
llama.generated_text.end());
}
}
auto probs = llama.generated_token_probs;
if (llama.params.n_probs > 0 && llama.stopped_word) {
const std::vector<llama_token> stop_word_toks = llama_tokenize(llama.ctx, llama.stopping_word, false);
probs = std::vector<completion_token_output>(llama.generated_token_probs.begin(), llama.generated_token_probs.end() - stop_word_toks.size());
}
reply->set_message(llama.generated_text);
return grpc::Status::OK;
}
};
void RunServer(const std::string& server_address) {
BackendServiceImpl service;
ServerBuilder builder;
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
builder.RegisterService(&service);
std::unique_ptr<Server> server(builder.BuildAndStart());
std::cout << "Server listening on " << server_address << std::endl;
server->Wait();
}
int main(int argc, char** argv) {
std::string server_address("localhost:50051");
// Define long and short options
struct option long_options[] = {
{"addr", required_argument, nullptr, 'a'},
{nullptr, 0, nullptr, 0}
};
// Parse command-line arguments
int option;
int option_index = 0;
while ((option = getopt_long(argc, argv, "a:", long_options, &option_index)) != -1) {
switch (option) {
case 'a':
server_address = optarg;
break;
default:
std::cerr << "Usage: " << argv[0] << " [--addr=<address>] or [-a <address>]" << std::endl;
return 1;
}
}
RunServer(server_address);
return 0;
}

View File

@ -17,6 +17,7 @@ import (
const (
LlamaBackend = "llama"
LlamaStableBackend = "llama-stable"
LLamaCPP = "llama-cpp"
BloomzBackend = "bloomz"
StarcoderBackend = "starcoder"
GPTJBackend = "gptj"
@ -41,8 +42,9 @@ const (
)
var AutoLoadBackends []string = []string{
LlamaBackend,
LLamaCPP,
LlamaStableBackend,
LlamaBackend,
Gpt4All,
FalconBackend,
GPTNeoXBackend,
@ -175,11 +177,6 @@ func (ml *ModelLoader) BackendLoader(opts ...Option) (model *grpc.Client, err er
}
switch backend {
case LlamaBackend, LlamaStableBackend, GPTJBackend, DollyBackend,
MPTBackend, Gpt2Backend, FalconBackend,
GPTNeoXBackend, ReplitBackend, StarcoderBackend, BloomzBackend,
RwkvBackend, LCHuggingFaceBackend, BertEmbeddingsBackend, FalconGGMLBackend, StableDiffusionBackend, WhisperBackend:
return ml.LoadModel(o.model, ml.grpcModel(backend, o))
case Gpt4AllLlamaBackend, Gpt4AllMptBackend, Gpt4AllJBackend, Gpt4All:
o.gRPCOptions.LibrarySearchPath = filepath.Join(o.assetDir, "backend-assets", "gpt4all")
return ml.LoadModel(o.model, ml.grpcModel(Gpt4All, o))
@ -187,7 +184,7 @@ func (ml *ModelLoader) BackendLoader(opts ...Option) (model *grpc.Client, err er
o.gRPCOptions.LibrarySearchPath = filepath.Join(o.assetDir, "backend-assets", "espeak-ng-data")
return ml.LoadModel(o.model, ml.grpcModel(PiperBackend, o))
default:
return nil, fmt.Errorf("backend unsupported: %s", o.backendString)
return ml.LoadModel(o.model, ml.grpcModel(backend, o))
}
}