diff --git a/Earthfile b/Earthfile new file mode 100644 index 00000000..4c13b7b9 --- /dev/null +++ b/Earthfile @@ -0,0 +1,40 @@ +VERSION 0.6 + +ARG GO_VERSION=1.20 +ARG GOLINT_VERSION=1.47.3 + +go-deps: + ARG GO_VERSION + FROM golang:$GO_VERSION + WORKDIR /build + COPY go.mod ./ + COPY go.sum ./ + RUN go mod download + RUN apt-get update + SAVE ARTIFACT go.mod AS LOCAL go.mod + SAVE ARTIFACT go.sum AS LOCAL go.sum + +alpaca-model: + FROM alpine + # This is the alpaca.cpp model https://github.com/antimatter15/alpaca.cpp + ARG MODEL_URL=https://ipfs.io/ipfs/QmQ1bf2BTnYxq73MFJWu1B7bQ2UD6qG7D7YDCxhTndVkPC + RUN wget -O model.bin -c https://ipfs.io/ipfs/QmQ1bf2BTnYxq73MFJWu1B7bQ2UD6qG7D7YDCxhTndVkPC + SAVE ARTIFACT model.bin AS LOCAL model.bin + +build: + FROM +go-deps + WORKDIR /build + RUN git clone https://github.com/go-skynet/llama + RUN cd llama && make libllama.a + COPY . . + RUN C_INCLUDE_PATH=/build/llama LIBRARY_PATH=/build/llama go build -o llama-cli ./ + SAVE ARTIFACT llama-cli AS LOCAL llama-cli + +image: + FROM +go-deps + ARG IMAGE=alpaca-cli + COPY +alpaca-model/model.bin /model.bin + COPY +build/llama-cli /llama-cli + ENV MODEL_PATH=/model.bin + ENTRYPOINT /llama-cli + SAVE IMAGE $IMAGE \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 00000000..b9c46f0a --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 go-skynet authors + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/api.go b/api.go new file mode 100644 index 00000000..c493d4df --- /dev/null +++ b/api.go @@ -0,0 +1,83 @@ +package main + +import ( + "strconv" + + llama "github.com/go-skynet/llama/go" + "github.com/gofiber/fiber/v2" +) + +func api(model, listenAddr string, threads int) error { + app := fiber.New() + + l, err := llama.New(model, 0) + if err != nil { + return err + } + + /* + curl --location --request POST 'http://localhost:8080/predict' --header 'Content-Type: application/json' --data-raw '{ + "text": "What is an alpaca?", + "topP": 0.8, + "topK": 50, + "temperature": 0.7, + "tokens": 100 + }' + */ + + // Endpoint to generate the prediction + app.Post("/predict", func(c *fiber.Ctx) error { + // Get input data from the request body + input := new(struct { + Text string `json:"text"` + }) + if err := c.BodyParser(input); err != nil { + return err + } + + // Set the parameters for the language model prediction + topP, err := strconv.ParseFloat(c.Query("topP", "0.9"), 64) // Default value of topP is 0.9 + if err != nil { + return err + } + + topK, err := strconv.Atoi(c.Query("topK", "40")) // Default value of topK is 40 + if err != nil { + return err + } + + temperature, err := strconv.ParseFloat(c.Query("temperature", "0.5"), 64) // Default value of temperature is 0.5 + if err != nil { + return err + } + + tokens, err := strconv.Atoi(c.Query("tokens", "128")) // Default value of tokens is 128 + if err != nil { + return err + } + + // Generate the prediction using the language model + prediction, err := l.Predict( + input.Text, + llama.SetTemperature(temperature), + llama.SetTopP(topP), + llama.SetTopK(topK), + llama.SetTokens(tokens), + llama.SetThreads(threads), + ) + if err != nil { + return err + } + + // Return the prediction in the response body + return c.JSON(struct { + Prediction string `json:"prediction"` + }{ + Prediction: prediction, + }) + }) + + // Start the server + app.Listen(":8080") + return nil +} diff --git a/go.mod b/go.mod new file mode 100644 index 00000000..97aeae26 --- /dev/null +++ b/go.mod @@ -0,0 +1,30 @@ +module github.com/go-skynet/llama-cli + +go 1.19 + +require ( + github.com/go-skynet/llama v0.0.0-20230318101759-56080ad745d1 + github.com/gofiber/fiber/v2 v2.42.0 + github.com/urfave/cli/v2 v2.25.0 +) + +require ( + github.com/andybalholm/brotli v1.0.4 // indirect + github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect + github.com/google/uuid v1.3.0 // indirect + github.com/klauspost/compress v1.15.9 // indirect + github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-isatty v0.0.17 // indirect + github.com/mattn/go-runewidth v0.0.14 // indirect + github.com/philhofer/fwd v1.1.1 // indirect + github.com/rivo/uniseg v0.2.0 // indirect + github.com/russross/blackfriday/v2 v2.1.0 // indirect + github.com/savsgio/dictpool v0.0.0-20221023140959-7bf2e61cea94 // indirect + github.com/savsgio/gotils v0.0.0-20220530130905-52f3993e8d6d // indirect + github.com/tinylib/msgp v1.1.6 // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect + github.com/valyala/fasthttp v1.44.0 // indirect + github.com/valyala/tcplisten v1.0.0 // indirect + github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect + golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 00000000..01de1298 --- /dev/null +++ b/go.sum @@ -0,0 +1,75 @@ +github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY= +github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= +github.com/cpuguy83/go-md2man/v2 v2.0.2 h1:p1EgwI/C7NhT0JmVkwCD2ZBK8j4aeHQX2pMHHBfMQ6w= +github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/go-skynet/llama v0.0.0-20230318101759-56080ad745d1 h1:0DwtVqERXmPTnzjv6pOPUgk9rOhC9ipD2Xn/SrW6Iro= +github.com/go-skynet/llama v0.0.0-20230318101759-56080ad745d1/go.mod h1:ZtYsAIud4cvP9VTTI9uhdgR1uCwaO/gGKnZZ95h9i7w= +github.com/gofiber/fiber/v2 v2.42.0 h1:Fnp7ybWvS+sjNQsFvkhf4G8OhXswvB6Vee8hM/LyS+8= +github.com/gofiber/fiber/v2 v2.42.0/go.mod h1:3+SGNjqMh5VQH5Vz2Wdi43zTIV16ktlFd3x3R6O1Zlc= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/klauspost/compress v1.15.9 h1:wKRjX6JRtDdrE9qwa4b/Cip7ACOshUI4smpCQanqjSY= +github.com/klauspost/compress v1.15.9/go.mod h1:PhcZ0MbTNciWF3rruxRgKxI5NkcHHrHUDtV4Yw2GlzU= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.17 h1:BTarxUcIeDqL27Mc+vyvdWYSL28zpIhv3RoTdsLMPng= +github.com/mattn/go-isatty v0.0.17/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-runewidth v0.0.14 h1:+xnbZSEeDbOIg5/mE6JF0w6n9duR1l3/WmbinWVwUuU= +github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= +github.com/philhofer/fwd v1.1.1 h1:GdGcTjf5RNAxwS4QLsiMzJYj5KEvPJD3Abr261yRQXQ= +github.com/philhofer/fwd v1.1.1/go.mod h1:gk3iGcWd9+svBvR0sR+KPcfE+RNWozjowpeBVG3ZVNU= +github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= +github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= +github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/savsgio/dictpool v0.0.0-20221023140959-7bf2e61cea94 h1:rmMl4fXJhKMNWl+K+r/fq4FbbKI+Ia2m9hYBLm2h4G4= +github.com/savsgio/dictpool v0.0.0-20221023140959-7bf2e61cea94/go.mod h1:90zrgN3D/WJsDd1iXHT96alCoN2KJo6/4x1DZC3wZs8= +github.com/savsgio/gotils v0.0.0-20220530130905-52f3993e8d6d h1:Q+gqLBOPkFGHyCJxXMRqtUgUbTjI8/Ze8vu8GGyNFwo= +github.com/savsgio/gotils v0.0.0-20220530130905-52f3993e8d6d/go.mod h1:Gy+0tqhJvgGlqnTF8CVGP0AaGRjwBtXs/a5PA0Y3+A4= +github.com/tinylib/msgp v1.1.6 h1:i+SbKraHhnrf9M5MYmvQhFnbLhAXSDWF8WWsuyRdocw= +github.com/tinylib/msgp v1.1.6/go.mod h1:75BAfg2hauQhs3qedfdDZmWAPcFMAvJE5b9rGOMufyw= +github.com/urfave/cli/v2 v2.25.0 h1:ykdZKuQey2zq0yin/l7JOm9Mh+pg72ngYMeB0ABn6q8= +github.com/urfave/cli/v2 v2.25.0/go.mod h1:GHupkWPMM0M/sj1a2b4wUrWBPzazNrIjouW6fmdJLxc= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/fasthttp v1.44.0 h1:R+gLUhldIsfg1HokMuQjdQ5bh9nuXHPIfvkYUu9eR5Q= +github.com/valyala/fasthttp v1.44.0/go.mod h1:f6VbjjoI3z1NDOZOv17o6RvtRSWxC77seBFc2uWtgiY= +github.com/valyala/tcplisten v1.0.0 h1:rBHj/Xf+E1tRGZyWIWwJDiRY0zc1Js+CV5DqwacVSA8= +github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc= +github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 h1:bAn7/zixMGCfxrRTfdpNzjtPYqr8smhKouy9mxVdGPU= +github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673/go.mod h1:N3UwUGtsrSj3ccvlPHLoLsHnpR27oXr4ZE984MbSER8= +github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20220214200702-86341886e292/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20220906165146-f3363e06e74c/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab h1:2QkjZIsXupsJbJIdSjjUOgWK3aEtzyuh2mPt3l/CkeU= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20201022035929-9cf592e881e9/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/main.go b/main.go new file mode 100644 index 00000000..a630a12c --- /dev/null +++ b/main.go @@ -0,0 +1,209 @@ +package main + +import ( + "bytes" + "fmt" + "io/ioutil" + "os" + "runtime" + "text/template" + + llama "github.com/go-skynet/llama/go" + "github.com/urfave/cli/v2" +) + +// Define the template string +var emptyInput string = `Below is an instruction that describes a task. Write a response that appropriately completes the request. + +### Instruction: +{{.Instruction}} + +### Response:` + +var nonEmptyInput string = `Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. + +### Instruction: +{{.Instruction}} + +### Input: +{{.Input}} + +### Response: +` + +func templateString(t string, in interface{}) (string, error) { + // Parse the template + tmpl, err := template.New("prompt").Parse(t) + if err != nil { + return "", err + } + + var buf bytes.Buffer + err = tmpl.Execute(&buf, in) + if err != nil { + return "", err + } + return buf.String(), nil +} + +func main() { + app := &cli.App{ + Name: "llama-cli", + Version: "0.1", + Usage: "llama-cli --model ... --instruction 'What is an alpaca?'", + Flags: []cli.Flag{ + &cli.StringFlag{ + Name: "template", + EnvVars: []string{"TEMPLATE"}, + }, + &cli.StringFlag{ + Name: "instruction", + EnvVars: []string{"INSTRUCTION"}, + }, + &cli.StringFlag{ + Name: "input", + EnvVars: []string{"INPUT"}, + }, + &cli.StringFlag{ + Name: "model", + EnvVars: []string{"MODEL_PATH"}, + }, + &cli.IntFlag{ + Name: "tokens", + EnvVars: []string{"TOKENS"}, + Value: 128, + }, + &cli.IntFlag{ + Name: "threads", + EnvVars: []string{"THREADS"}, + Value: runtime.NumCPU(), + }, + &cli.Float64Flag{ + Name: "temperature", + EnvVars: []string{"TEMPERATURE"}, + Value: 0.95, + }, + &cli.Float64Flag{ + Name: "topp", + EnvVars: []string{"TOP_P"}, + Value: 0.85, + }, + &cli.IntFlag{ + Name: "topk", + EnvVars: []string{"TOP_K"}, + Value: 20, + }, + }, + Description: `Run llama.cpp inference`, + UsageText: ` +llama-cli --model ~/ggml-alpaca-7b-q4.bin --instruction "What's an alpaca?" + + An Alpaca (Vicugna pacos) is a domesticated species of South American camelid, related to llamas and originally from Peru but now found throughout much of Andean region. They are bred for their fleeces which can be spun into wool or knitted items such as hats, sweaters, blankets etc + +echo "An Alpaca (Vicugna pacos) is a domesticated species of South American camelid, related to llamas and originally from Peru but now found throughout much of Andean region. They are bred for their fleeces which can be spun into wool or knitted items such as hats, sweaters, blankets etc" | llama-cli --model ~/ggml-alpaca-7b-q4.bin --instruction "Proofread, improving clarity and flow" --input "-" + + An Alpaca (Vicugna pacos) is a domesticated species from South America that's related to llamas. Originating in Peru but now found throughout the Andean region, they are bred for their fleeces which can be spun into wool or knitted items such as hats and sweaters—blankets too! +`, + Copyright: "go-skynet authors", + Commands: []*cli.Command{ + { + Name: "api", + Flags: []cli.Flag{ + &cli.IntFlag{ + Name: "threads", + EnvVars: []string{"THREADS"}, + Value: runtime.NumCPU(), + }, + &cli.StringFlag{ + Name: "model", + EnvVars: []string{"MODEL_PATH"}, + }, + &cli.StringFlag{ + Name: "address", + EnvVars: []string{"ADDRESS"}, + Value: ":8080", + }, + }, + Action: func(ctx *cli.Context) error { + return api(ctx.String("model"), ctx.String("address"), ctx.Int("threads")) + }, + }, + }, + Action: func(ctx *cli.Context) error { + + instruction := ctx.String("instruction") + input := ctx.String("input") + templ := ctx.String("template") + + promptTemplate := "" + + if input != "" { + promptTemplate = nonEmptyInput + } else { + promptTemplate = emptyInput + } + + if templ != "" { + dat, err := os.ReadFile(templ) + if err != nil { + fmt.Printf("Failed reading file: %s", err.Error()) + os.Exit(1) + } + promptTemplate = string(dat) + } + + if instruction == "-" { + dat, err := ioutil.ReadAll(os.Stdin) + if err != nil { + fmt.Printf("reading stdin failed: %s", err) + os.Exit(1) + } + instruction = string(dat) + } + + if input == "-" { + dat, err := ioutil.ReadAll(os.Stdin) + if err != nil { + fmt.Printf("reading stdin failed: %s", err) + os.Exit(1) + } + input = string(dat) + } + + str, err := templateString(promptTemplate, struct { + Instruction string + Input string + }{Instruction: instruction, Input: input}) + + if err != nil { + fmt.Println("Templating the input failed:", err.Error()) + os.Exit(1) + } + l, err := llama.New(ctx.String("model"), 0) + if err != nil { + fmt.Println("Loading the model failed:", err.Error()) + os.Exit(1) + } + res, err := l.Predict( + str, + llama.SetTemperature(ctx.Float64("temperature")), + llama.SetTopP(ctx.Float64("topp")), + llama.SetTopK(ctx.Int("topk")), + llama.SetTokens(ctx.Int("tokens")), + llama.SetThreads(ctx.Int("threads")), + ) + if err != nil { + fmt.Printf("predicting failed: %s", err) + os.Exit(1) + } + fmt.Println(res) + return nil + }, + } + + err := app.Run(os.Args) + if err != nil { + fmt.Println(err) + os.Exit(1) + } +}