diff --git a/Dockerfile b/Dockerfile index 55793a72..7f7ee817 100644 --- a/Dockerfile +++ b/Dockerfile @@ -16,7 +16,7 @@ ENV BUILD_TYPE=${BUILD_TYPE} ENV EXTERNAL_GRPC_BACKENDS="coqui:/build/backend/python/coqui/run.sh,huggingface-embeddings:/build/backend/python/sentencetransformers/run.sh,petals:/build/backend/python/petals/run.sh,transformers:/build/backend/python/transformers/run.sh,sentencetransformers:/build/backend/python/sentencetransformers/run.sh,autogptq:/build/backend/python/autogptq/run.sh,bark:/build/backend/python/bark/run.sh,diffusers:/build/backend/python/diffusers/run.sh,exllama:/build/backend/python/exllama/run.sh,vall-e-x:/build/backend/python/vall-e-x/run.sh,vllm:/build/backend/python/vllm/run.sh,exllama2:/build/backend/python/exllama2/run.sh,transformers-musicgen:/build/backend/python/transformers-musicgen/run.sh" ENV GALLERIES='[{"name":"model-gallery", "url":"github:go-skynet/model-gallery/index.yaml"}, {"url": "github:go-skynet/model-gallery/huggingface.yaml","name":"huggingface"}]' -ARG GO_TAGS="stablediffusion tts" +ARG GO_TAGS="stablediffusion tinydream tts" RUN apt-get update && \ apt-get install -y ca-certificates curl patch pip cmake && apt-get clean diff --git a/Makefile b/Makefile index 4bb41688..37ac3c8c 100644 --- a/Makefile +++ b/Makefile @@ -33,6 +33,9 @@ PIPER_VERSION?=d6b6275ba037dabdba4a8b65dfdf6b2a73a67f07 # stablediffusion version STABLEDIFFUSION_VERSION?=902db5f066fd137697e3b69d0fa10d4782bd2c2f +# tinydream version +TINYDREAM_VERSION?=772a9c0d9aaf768290e63cca3c904fe69faf677a + export BUILD_TYPE?= export STABLE_BUILD_TYPE?=$(BUILD_TYPE) export CMAKE_ARGS?= @@ -129,6 +132,11 @@ ifeq ($(findstring stablediffusion,$(GO_TAGS)),stablediffusion) OPTIONAL_GRPC+=backend-assets/grpc/stablediffusion endif +ifeq ($(findstring tinydream,$(GO_TAGS)),tinydream) +# OPTIONAL_TARGETS+=go-tiny-dream/libtinydream.a + OPTIONAL_GRPC+=backend-assets/grpc/tinydream +endif + ifeq ($(findstring tts,$(GO_TAGS)),tts) # OPTIONAL_TARGETS+=go-piper/libpiper_binding.a # OPTIONAL_TARGETS+=backend-assets/espeak-ng-data @@ -172,6 +180,14 @@ sources/go-stable-diffusion: sources/go-stable-diffusion/libstablediffusion.a: $(MAKE) -C sources/go-stable-diffusion libstablediffusion.a +## tiny-dream +sources/go-tiny-dream: + git clone --recurse-submodules https://github.com/M0Rf30/go-tiny-dream sources/go-tiny-dream + cd sources/go-tiny-dream && git checkout -b build $(TINYDREAM_VERSION) && git submodule update --init --recursive --depth 1 + +sources/go-tiny-dream/libtinydream.a: + $(MAKE) -C sources/go-tiny-dream libtinydream.a + ## RWKV sources/go-rwkv: git clone --recurse-submodules $(RWKV_REPO) sources/go-rwkv @@ -232,7 +248,7 @@ sources/go-piper/libpiper_binding.a: sources/go-piper backend/cpp/llama/llama.cpp: LLAMA_VERSION=$(CPPLLAMA_VERSION) $(MAKE) -C backend/cpp/llama llama.cpp -get-sources: backend/cpp/llama/llama.cpp sources/go-llama sources/go-llama-ggml sources/go-ggml-transformers sources/gpt4all sources/go-piper sources/go-rwkv sources/whisper.cpp sources/go-bert sources/go-stable-diffusion +get-sources: backend/cpp/llama/llama.cpp sources/go-llama sources/go-llama-ggml sources/go-ggml-transformers sources/gpt4all sources/go-piper sources/go-rwkv sources/whisper.cpp sources/go-bert sources/go-stable-diffusion sources/go-tiny-dream touch $@ replace: @@ -243,6 +259,7 @@ replace: $(GOCMD) mod edit -replace github.com/ggerganov/whisper.cpp/bindings/go=$(shell pwd)/sources/whisper.cpp/bindings/go $(GOCMD) mod edit -replace github.com/go-skynet/go-bert.cpp=$(shell pwd)/sources/go-bert $(GOCMD) mod edit -replace github.com/mudler/go-stable-diffusion=$(shell pwd)/sources/go-stable-diffusion + $(GOCMD) mod edit -replace github.com/M0Rf30/go-tiny-dream=$(shell pwd)/sources/go-tiny-dream $(GOCMD) mod edit -replace github.com/mudler/go-piper=$(shell pwd)/sources/go-piper prepare-sources: get-sources replace @@ -261,6 +278,7 @@ rebuild: ## Rebuilds the project $(MAKE) -C sources/go-stable-diffusion clean $(MAKE) -C sources/go-bert clean $(MAKE) -C sources/go-piper clean + $(MAKE) -C sources/go-tiny-dream clean $(MAKE) build prepare: prepare-sources $(OPTIONAL_TARGETS) @@ -524,9 +542,13 @@ backend-assets/grpc/stablediffusion: backend-assets/grpc if [ ! -f backend-assets/grpc/stablediffusion ]; then \ $(MAKE) sources/go-stable-diffusion/libstablediffusion.a; \ CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/sources/go-stable-diffusion/ LIBRARY_PATH=$(shell pwd)/sources/go-stable-diffusion/ \ - $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/stablediffusion ./backend/go/image/; \ + $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/stablediffusion ./backend/go/image/stablediffusion; \ fi +backend-assets/grpc/tinydream: backend-assets/grpc sources/go-tiny-dream/libtinydream.a + CGO_LDFLAGS="$(CGO_LDFLAGS)" LIBRARY_PATH=$(shell pwd)/go-tiny-dream \ + $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/tinydream ./backend/go/image/tinydream + backend-assets/grpc/piper: backend-assets/grpc backend-assets/espeak-ng-data sources/go-piper/libpiper_binding.a CGO_CXXFLAGS="$(PIPER_CGO_CXXFLAGS)" CGO_LDFLAGS="$(PIPER_CGO_LDFLAGS)" LIBRARY_PATH=$(shell pwd)/sources/go-piper \ $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/piper ./backend/go/tts/ diff --git a/api/openai/image.go b/api/openai/image.go index 8f806275..ba1ba39b 100644 --- a/api/openai/image.go +++ b/api/openai/image.go @@ -122,8 +122,12 @@ func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx log.Debug().Msgf("Parameter Config: %+v", config) - // XXX: Only stablediffusion is supported for now - if config.Backend == "" { + switch config.Backend { + case "stablediffusion": + config.Backend = model.StableDiffusionBackend + case "tinydream": + config.Backend = model.TinyDreamBackend + default: config.Backend = model.StableDiffusionBackend } diff --git a/backend/go/image/main.go b/backend/go/image/stablediffusion/main.go similarity index 81% rename from backend/go/image/main.go rename to backend/go/image/stablediffusion/main.go index 425b2340..07c88500 100644 --- a/backend/go/image/main.go +++ b/backend/go/image/stablediffusion/main.go @@ -15,7 +15,7 @@ var ( func main() { flag.Parse() - if err := grpc.StartServer(*addr, &StableDiffusion{}); err != nil { + if err := grpc.StartServer(*addr, &Image{}); err != nil { panic(err) } } diff --git a/backend/go/image/stablediffusion.go b/backend/go/image/stablediffusion/stablediffusion.go similarity index 71% rename from backend/go/image/stablediffusion.go rename to backend/go/image/stablediffusion/stablediffusion.go index f04ee1de..0f6966f5 100644 --- a/backend/go/image/stablediffusion.go +++ b/backend/go/image/stablediffusion/stablediffusion.go @@ -8,20 +8,20 @@ import ( "github.com/go-skynet/LocalAI/pkg/stablediffusion" ) -type StableDiffusion struct { +type Image struct { base.SingleThread stablediffusion *stablediffusion.StableDiffusion } -func (sd *StableDiffusion) Load(opts *pb.ModelOptions) error { +func (image *Image) Load(opts *pb.ModelOptions) error { var err error // Note: the Model here is a path to a directory containing the model files - sd.stablediffusion, err = stablediffusion.New(opts.ModelFile) + image.stablediffusion, err = stablediffusion.New(opts.ModelFile) return err } -func (sd *StableDiffusion) GenerateImage(opts *pb.GenerateImageRequest) error { - return sd.stablediffusion.GenerateImage( +func (image *Image) GenerateImage(opts *pb.GenerateImageRequest) error { + return image.stablediffusion.GenerateImage( int(opts.Height), int(opts.Width), int(opts.Mode), diff --git a/backend/go/image/tinydream/main.go b/backend/go/image/tinydream/main.go new file mode 100644 index 00000000..07c88500 --- /dev/null +++ b/backend/go/image/tinydream/main.go @@ -0,0 +1,21 @@ +package main + +// Note: this is started internally by LocalAI and a server is allocated for each model + +import ( + "flag" + + grpc "github.com/go-skynet/LocalAI/pkg/grpc" +) + +var ( + addr = flag.String("addr", "localhost:50051", "the address to connect to") +) + +func main() { + flag.Parse() + + if err := grpc.StartServer(*addr, &Image{}); err != nil { + panic(err) + } +} diff --git a/backend/go/image/tinydream/tinydream.go b/backend/go/image/tinydream/tinydream.go new file mode 100644 index 00000000..3dc9d0c0 --- /dev/null +++ b/backend/go/image/tinydream/tinydream.go @@ -0,0 +1,32 @@ +package main + +// This is a wrapper to statisfy the GRPC service interface +// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc) +import ( + "github.com/go-skynet/LocalAI/pkg/grpc/base" + pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" + "github.com/go-skynet/LocalAI/pkg/tinydream" +) + +type Image struct { + base.SingleThread + tinydream *tinydream.TinyDream +} + +func (image *Image) Load(opts *pb.ModelOptions) error { + var err error + // Note: the Model here is a path to a directory containing the model files + image.tinydream, err = tinydream.New(opts.ModelFile) + return err +} + +func (image *Image) GenerateImage(opts *pb.GenerateImageRequest) error { + return image.tinydream.GenerateImage( + int(opts.Height), + int(opts.Width), + int(opts.Step), + int(opts.Seed), + opts.PositivePrompt, + opts.NegativePrompt, + opts.Dst) +} diff --git a/go.mod b/go.mod index 3262a68b..250a2361 100644 --- a/go.mod +++ b/go.mod @@ -3,10 +3,10 @@ module github.com/go-skynet/LocalAI go 1.21 require ( + github.com/M0Rf30/go-tiny-dream v0.0.0-20231128165230-772a9c0d9aaf github.com/donomii/go-rwkv.cpp v0.0.0-20230715075832-c898cd0f62df github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230628193450-85ed71aaec8e github.com/go-audio/wav v1.1.0 - github.com/go-skynet/bloomz.cpp v0.0.0-20230529155654-1834e77b83fa github.com/go-skynet/go-bert.cpp v0.0.0-20230716133540-6abe312cded1 github.com/go-skynet/go-ggml-transformers.cpp v0.0.0-20230714203132-ffb09d7dd71e github.com/go-skynet/go-llama.cpp v0.0.0-20231009155254-aeba71ee8428 @@ -17,7 +17,6 @@ require ( github.com/imdario/mergo v0.3.16 github.com/json-iterator/go v1.1.12 github.com/mholt/archiver/v3 v3.5.1 - github.com/mudler/go-ggllm.cpp v0.0.0-20230709223052-862477d16eef github.com/mudler/go-processmanager v0.0.0-20230818213616-f204007f963c github.com/mudler/go-stable-diffusion v0.0.0-20230605122230-d89260f598af github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20231022042237-c25dc5193530 diff --git a/go.sum b/go.sum index 239bb85d..fc00bf6e 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/M0Rf30/go-tiny-dream v0.0.0-20231128165230-772a9c0d9aaf h1:UgjXLcE9I+VaVz7uBIlzAnyZIXwiDlIiTWqCh159aUI= +github.com/M0Rf30/go-tiny-dream v0.0.0-20231128165230-772a9c0d9aaf/go.mod h1:UOf2Mb/deUri5agct5OJ4SLWjhI+kZKbsUVUeRb24I0= github.com/andybalholm/brotli v1.0.1/go.mod h1:loMXtMfwqflxFJPmdbJO0a3KNoPuLBgiu3qAvBg8x/Y= github.com/andybalholm/brotli v1.0.5 h1:8uQZIdzKmjc/iuPu7O2ioW48L81FgatrcpfFmiq/cCs= github.com/andybalholm/brotli v1.0.5/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= @@ -39,8 +41,6 @@ github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY= github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= -github.com/go-skynet/bloomz.cpp v0.0.0-20230529155654-1834e77b83fa h1:gxr68r/6EWroay4iI81jxqGCDbKotY4+CiwdUkBz2NQ= -github.com/go-skynet/bloomz.cpp v0.0.0-20230529155654-1834e77b83fa/go.mod h1:wc0fJ9V04yiYTfgKvE5RUUSRQ5Kzi0Bo4I+U3nNOUuA= github.com/go-skynet/go-bert.cpp v0.0.0-20230716133540-6abe312cded1 h1:yXvc7QfGtoZ51tUW/YVjoTwAfh8HG88XU7UOrbNlz5Y= github.com/go-skynet/go-bert.cpp v0.0.0-20230716133540-6abe312cded1/go.mod h1:fYjkCDRzC+oRLHSjQoajmYK6AmeJnmEanV27CClAcDc= github.com/go-skynet/go-ggml-transformers.cpp v0.0.0-20230714203132-ffb09d7dd71e h1:4reMY29i1eOZaRaSTMPNyXI7X8RMNxCTfDDBXYzrbr0= @@ -71,7 +71,6 @@ github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= @@ -125,18 +124,12 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= -github.com/mudler/go-ggllm.cpp v0.0.0-20230709223052-862477d16eef h1:OJZtJ5vYhlkTJI0RHIl62kOkhiINQEhZgsXlwmmNDhM= -github.com/mudler/go-ggllm.cpp v0.0.0-20230709223052-862477d16eef/go.mod h1:00giAi/vwF8LX29JBjkPQhtASsivPnGNzB6sdmk8JGE= github.com/mudler/go-piper v0.0.0-20230621222733-56b8a81b4760 h1:OFVkSxR7CRSRSNm5dvpMRZwmSwWa8EMMnHbc84fW5tU= github.com/mudler/go-piper v0.0.0-20230621222733-56b8a81b4760/go.mod h1:O7SwdSWMilAWhBZMK9N9Y/oBDyMMzshE3ju8Xkexwig= github.com/mudler/go-processmanager v0.0.0-20230818213616-f204007f963c h1:CI5uGwqBpN8N7BrSKC+nmdfw+9nPQIDyjHHlaIiitZI= github.com/mudler/go-processmanager v0.0.0-20230818213616-f204007f963c/go.mod h1:gY3wyrhkRySJtmtI/JPt4a2mKv48h/M9pEZIW+SjeC0= github.com/mudler/go-stable-diffusion v0.0.0-20230605122230-d89260f598af h1:XFq6OUqsWQam0OrEr05okXsJK/TQur3zoZTHbiZD3Ks= github.com/mudler/go-stable-diffusion v0.0.0-20230605122230-d89260f598af/go.mod h1:8ufRkpz/S/9ahkaxzZ5i4WMgO9w4InEhuRoT7vK5Rnw= -github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20231013181651-22de3c56bdd4 h1:82J4t94Mmt0lva/OoxNlHkKrMSdSUZXkAjTFnlFFsow= -github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20231013181651-22de3c56bdd4/go.mod h1:4T3CHXyrt+7FQHXaxULZfPjHbD8/99WuDDJa0YVZARI= -github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20231016205817-9a19c740ee84 h1:AiFzd+M2Uxz67fdn4nCnKR70me5yf88rXhoqhvfRDak= -github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20231016205817-9a19c740ee84/go.mod h1:4T3CHXyrt+7FQHXaxULZfPjHbD8/99WuDDJa0YVZARI= github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20231022042237-c25dc5193530 h1:YXMxHwHMB9jCBo2Yu5gz3mTB3T1TnZs/HmPLv15LUSA= github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20231022042237-c25dc5193530/go.mod h1:4T3CHXyrt+7FQHXaxULZfPjHbD8/99WuDDJa0YVZARI= github.com/nwaples/rardecode v1.1.0 h1:vSxaY8vQhOcVr4mm5e8XllHWTiM4JF507A0Katqw7MQ= @@ -153,8 +146,6 @@ github.com/onsi/ginkgo/v2 v2.13.0/go.mod h1:TE309ZR8s5FsKKpuB1YAQYBzCaAfUgatB/xl github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= github.com/onsi/gomega v1.16.0/go.mod h1:HnhC7FXeEQY45zxNK3PPoIUhzk/80Xly9PcubAlGdZY= -github.com/onsi/gomega v1.28.0 h1:i2rg/p9n/UqIDAMFUJ6qIUUMcsqOuUHgbpbu235Vr1c= -github.com/onsi/gomega v1.28.0/go.mod h1:A1H2JE76sI14WIP57LMKj7FVfCHx3g3BcZVjJG8bjX8= github.com/onsi/gomega v1.28.1 h1:MijcGUbfYuznzK/5R4CPNoUP/9Xvuo20sXfEm6XxoTA= github.com/onsi/gomega v1.28.1/go.mod h1:9sxs+SwGrKI0+PWe4Fxa9tFQQBG5xSsSbMXOI8PPpoQ= github.com/otiai10/mint v1.6.1 h1:kgbTJmOpp/0ce7hk3H8jiSuR0MXmpwWRfqUdKww17qg= @@ -216,8 +207,6 @@ github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0h github.com/tklauser/numcpus v0.6.0/go.mod h1:FEZLMke0lhOUG6w2JadTzp0a+Nl8PF/GFkQ5UVIcaL4= github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk= github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY= -github.com/tmc/langchaingo v0.0.0-20231016073620-a02d4fdc0f3a h1:BziGpoF5ZVWMDy6Z1adXnYndRye2fiYWZlmknUFksGA= -github.com/tmc/langchaingo v0.0.0-20231016073620-a02d4fdc0f3a/go.mod h1:SiwyRS7sBSSi6f3NB4dKENw69X6br/wZ2WRkM+8pZWk= github.com/tmc/langchaingo v0.0.0-20231019140956-c636b3da7701 h1:LquLgmFiKf6eDXdwoUKCIGn5NsR34cLXC6ySYhiE6bA= github.com/tmc/langchaingo v0.0.0-20231019140956-c636b3da7701/go.mod h1:SiwyRS7sBSSi6f3NB4dKENw69X6br/wZ2WRkM+8pZWk= github.com/ulikunitz/xz v0.5.8/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14= @@ -309,12 +298,8 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98 h1:bVf09lpb+OJbByTj913DRJioFFAjf/ZGxEz7MajTp2U= -google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98/go.mod h1:TUfxEVdsvPg18p6AslUXFoLdpED4oBnGwyqk3dV1XzM= google.golang.org/genproto/googleapis/rpc v0.0.0-20230822172742-b8732ec3820d h1:uvYuEyMHKNt+lT4K3bN6fGswmK8qSvcreM3BwjDh+y4= google.golang.org/genproto/googleapis/rpc v0.0.0-20230822172742-b8732ec3820d/go.mod h1:+Bk1OCOj40wS2hwAMA+aCW9ypzm63QTBBHp6lQ3p+9M= -google.golang.org/grpc v1.58.3 h1:BjnpXut1btbtgN/6sp+brB2Kbm2LjNXnidYujAVbSoQ= -google.golang.org/grpc v1.58.3/go.mod h1:tgX3ZQDlNJGU96V6yHh1T/JeoBQ2TXdr43YbYSsCJk0= google.golang.org/grpc v1.59.0 h1:Z5Iec2pjwb+LEOqzpB2MR12/eKFhDPhuqW91O+4bwUk= google.golang.org/grpc v1.59.0/go.mod h1:aUPDwccQo6OTjy7Hct4AfBPD1GptF4fyUjIkQ9YtF98= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= diff --git a/pkg/model/initializers.go b/pkg/model/initializers.go index 796dc5ae..3195fac9 100644 --- a/pkg/model/initializers.go +++ b/pkg/model/initializers.go @@ -40,6 +40,7 @@ const ( RwkvBackend = "rwkv" WhisperBackend = "whisper" StableDiffusionBackend = "stablediffusion" + TinyDreamBackend = "tinydream" PiperBackend = "piper" LCHuggingFaceBackend = "langchain-huggingface" @@ -64,6 +65,7 @@ var AutoLoadBackends []string = []string{ RwkvBackend, WhisperBackend, StableDiffusionBackend, + TinyDreamBackend, PiperBackend, } diff --git a/pkg/tinydream/generate.go b/pkg/tinydream/generate.go new file mode 100644 index 00000000..cfcd23cc --- /dev/null +++ b/pkg/tinydream/generate.go @@ -0,0 +1,36 @@ +//go:build tinydream +// +build tinydream + +package tinydream + +import ( + "fmt" + "path/filepath" + + tinyDream "github.com/M0Rf30/go-tiny-dream" +) + +func GenerateImage(height, width, step, seed int, positive_prompt, negative_prompt, dst, asset_dir string) error { + fmt.Println(dst) + if height > 512 || width > 512 { + return tinyDream.GenerateImage( + 1, + step, + seed, + positive_prompt, + negative_prompt, + filepath.Dir(dst), + asset_dir, + ) + } + + return tinyDream.GenerateImage( + 0, + step, + seed, + positive_prompt, + negative_prompt, + filepath.Dir(dst), + asset_dir, + ) +} diff --git a/pkg/tinydream/generate_unsupported.go b/pkg/tinydream/generate_unsupported.go new file mode 100644 index 00000000..4ffd421a --- /dev/null +++ b/pkg/tinydream/generate_unsupported.go @@ -0,0 +1,10 @@ +//go:build !tinydream +// +build !tinydream + +package tinydream + +import "fmt" + +func GenerateImage(height, width, step, seed int, positive_prompt, negative_prompt, dst, asset_dir string) error { + return fmt.Errorf("This version of LocalAI was built without the tinytts tag") +} diff --git a/pkg/tinydream/tinydream.go b/pkg/tinydream/tinydream.go new file mode 100644 index 00000000..a316e641 --- /dev/null +++ b/pkg/tinydream/tinydream.go @@ -0,0 +1,20 @@ +package tinydream + +import "os" + +type TinyDream struct { + assetDir string +} + +func New(assetDir string) (*TinyDream, error) { + if _, err := os.Stat(assetDir); err != nil { + return nil, err + } + return &TinyDream{ + assetDir: assetDir, + }, nil +} + +func (td *TinyDream) GenerateImage(height, width, step, seed int, positive_prompt, negative_prompt, dst string) error { + return GenerateImage(height, width, step, seed, positive_prompt, negative_prompt, dst, td.assetDir) +}