From 4d98dd9ce7e4d3a5eb37c3f37fef0a61710beb9e Mon Sep 17 00:00:00 2001
From: Prajwal S Nayak
Date: Wed, 29 May 2024 18:10:54 +0530
Subject: [PATCH] feat(image): support `response_type` in the OpenAI API request (#2347)

* Change response_format type to string to match OpenAI Spec

Signed-off-by: prajwal

* updated response_type type to interface

Signed-off-by: prajwal

* feat: correctly parse generic struct

Signed-off-by: mudler

* add tests

Signed-off-by: mudler

---------

Signed-off-by: prajwal
Signed-off-by: mudler
Co-authored-by: Ettore Di Giacinto
Co-authored-by: mudler
---
 core/config/backend_config.go            |  8 +++++---
 core/http/endpoints/openai/chat.go       |  9 +++++++--
 core/http/endpoints/openai/completion.go |  9 +++++++--
 core/http/endpoints/openai/image.go      |  6 ++----
 core/http/endpoints/openai/request.go    |  9 +++++++++
 core/schema/openai.go                    |  4 +++-
 tests/e2e-aio/e2e_test.go                | 25 ++++++++++++++++++++++++-
 7 files changed, 57 insertions(+), 13 deletions(-)

diff --git a/core/config/backend_config.go b/core/config/backend_config.go
index a4979233..eda66360 100644
--- a/core/config/backend_config.go
+++ b/core/config/backend_config.go
@@ -27,9 +27,11 @@ type BackendConfig struct {
 	Backend        string         `yaml:"backend"`
 	TemplateConfig TemplateConfig `yaml:"template"`
 
-	PromptStrings, InputStrings                []string `yaml:"-"`
-	InputToken                                 [][]int  `yaml:"-"`
-	functionCallString, functionCallNameString string   `yaml:"-"`
+	PromptStrings, InputStrings                []string               `yaml:"-"`
+	InputToken                                 [][]int                `yaml:"-"`
+	functionCallString, functionCallNameString string                 `yaml:"-"`
+	ResponseFormat                             string                 `yaml:"-"`
+	ResponseFormatMap                          map[string]interface{} `yaml:"-"`
 
 	FunctionsConfig functions.FunctionsConfig `yaml:"function"`
 
diff --git a/core/http/endpoints/openai/chat.go b/core/http/endpoints/openai/chat.go
index b2e7aa75..6b4899a5 100644
--- a/core/http/endpoints/openai/chat.go
+++ b/core/http/endpoints/openai/chat.go
@@ -183,8 +183,13 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
 		noActionDescription = config.FunctionsConfig.NoActionDescriptionName
 	}
 
-	if input.ResponseFormat.Type == "json_object" {
-		input.Grammar = functions.JSONBNF
+	if config.ResponseFormatMap != nil {
+		d := schema.ChatCompletionResponseFormat{}
+		dat, _ := json.Marshal(config.ResponseFormatMap)
+		_ = json.Unmarshal(dat, &d)
+		if d.Type == "json_object" {
+			input.Grammar = functions.JSONBNF
+		}
 	}
 
 	config.Grammar = input.Grammar
diff --git a/core/http/endpoints/openai/completion.go b/core/http/endpoints/openai/completion.go
index bcd46db5..9554a2dc 100644
--- a/core/http/endpoints/openai/completion.go
+++ b/core/http/endpoints/openai/completion.go
@@ -69,8 +69,13 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a
 			return fmt.Errorf("failed reading parameters from request:%w", err)
 		}
 
-		if input.ResponseFormat.Type == "json_object" {
-			input.Grammar = functions.JSONBNF
+		if config.ResponseFormatMap != nil {
+			d := schema.ChatCompletionResponseFormat{}
+			dat, _ := json.Marshal(config.ResponseFormatMap)
+			_ = json.Unmarshal(dat, &d)
+			if d.Type == "json_object" {
+				input.Grammar = functions.JSONBNF
+			}
 		}
 
 		config.Grammar = input.Grammar
diff --git a/core/http/endpoints/openai/image.go b/core/http/endpoints/openai/image.go
index 9e806b3e..9de513a4 100644
--- a/core/http/endpoints/openai/image.go
+++ b/core/http/endpoints/openai/image.go
@@ -149,10 +149,8 @@ func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appCon
 			return fmt.Errorf("invalid value for 'size'")
 		}
 
-		b64JSON := false
-		if input.ResponseFormat.Type == "b64_json" {
-			b64JSON = true
-		}
+		b64JSON := config.ResponseFormat == "b64_json"
+
 		// src and clip_skip
 		var result []schema.Item
 		for _, i := range config.PromptStrings {
diff --git a/core/http/endpoints/openai/request.go b/core/http/endpoints/openai/request.go
index d25e05b5..941a66e3 100644
--- a/core/http/endpoints/openai/request.go
+++ b/core/http/endpoints/openai/request.go
@@ -129,6 +129,15 @@ func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIReque
 		config.Maxtokens = input.Maxtokens
 	}
 
+	if input.ResponseFormat != nil {
+		switch responseFormat := input.ResponseFormat.(type) {
+		case string:
+			config.ResponseFormat = responseFormat
+		case map[string]interface{}:
+			config.ResponseFormatMap = responseFormat
+		}
+	}
+
 	switch stop := input.Stop.(type) {
 	case string:
 		if stop != "" {
diff --git a/core/schema/openai.go b/core/schema/openai.go
index 177dc7ec..ec8c2c3b 100644
--- a/core/schema/openai.go
+++ b/core/schema/openai.go
@@ -99,6 +99,8 @@ type OpenAIModel struct {
 	Object string `json:"object"`
 }
 
+type ImageGenerationResponseFormat string
+
 type ChatCompletionResponseFormatType string
 
 type ChatCompletionResponseFormat struct {
@@ -114,7 +116,7 @@ type OpenAIRequest struct {
 	// whisper
 	File string `json:"file" validate:"required"`
 	//whisper/image
-	ResponseFormat ChatCompletionResponseFormat `json:"response_format"`
+	ResponseFormat interface{} `json:"response_format,omitempty"`
 	// image
 	Size string `json:"size"`
 	// Prompt is read only by completion/image API calls
diff --git a/tests/e2e-aio/e2e_test.go b/tests/e2e-aio/e2e_test.go
index 8fcd1280..670b3465 100644
--- a/tests/e2e-aio/e2e_test.go
+++ b/tests/e2e-aio/e2e_test.go
@@ -123,13 +123,36 @@ var _ = Describe("E2E test", func() {
 				openai.ImageRequest{
 					Prompt: "test",
 					Size:   openai.CreateImageSize512x512,
-					//ResponseFormat: openai.CreateImageResponseFormatURL,
 				},
 			)
 			Expect(err).ToNot(HaveOccurred())
 			Expect(len(resp.Data)).To(Equal(1), fmt.Sprint(resp))
 			Expect(resp.Data[0].URL).To(ContainSubstring("png"), fmt.Sprint(resp.Data[0].URL))
 		})
+		It("correctly changes the response format to url", func() {
+			resp, err := client.CreateImage(context.TODO(),
+				openai.ImageRequest{
+					Prompt:         "test",
+					Size:           openai.CreateImageSize512x512,
+					ResponseFormat: openai.CreateImageResponseFormatURL,
+				},
+			)
+			Expect(err).ToNot(HaveOccurred())
+			Expect(len(resp.Data)).To(Equal(1), fmt.Sprint(resp))
+			Expect(resp.Data[0].URL).To(ContainSubstring("png"), fmt.Sprint(resp.Data[0].URL))
+		})
+		It("correctly changes the response format to base64", func() {
+			resp, err := client.CreateImage(context.TODO(),
+				openai.ImageRequest{
+					Prompt:         "test",
+					Size:           openai.CreateImageSize512x512,
+					ResponseFormat: openai.CreateImageResponseFormatB64JSON,
+				},
+			)
+			Expect(err).ToNot(HaveOccurred())
+			Expect(len(resp.Data)).To(Equal(1), fmt.Sprint(resp))
+			Expect(resp.Data[0].B64JSON).ToNot(BeEmpty(), fmt.Sprint(resp.Data[0].B64JSON))
+		})
 	})
 	Context("embeddings", func() {
 		It("correctly", func() {
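A note on the parsing pattern used above: `response_format` is accepted as an untyped JSON value because the OpenAI API sends it as a plain string for image requests (`"url"` / `"b64_json"`) but as an object (`{"type": "json_object"}`) for chat and completion requests. The sketch below shows that pattern in isolation; it is a minimal, standalone example using hypothetical `request` and `chatResponseFormat` types, not code taken from this patch.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// chatResponseFormat mirrors the object form of response_format,
// e.g. {"type": "json_object"}. The name is illustrative only.
type chatResponseFormat struct {
	Type string `json:"type"`
}

// request leaves response_format untyped so it can hold either the
// plain-string form used by the image endpoint or the object form
// used by chat/completion. This is a hypothetical request type.
type request struct {
	ResponseFormat interface{} `json:"response_format,omitempty"`
}

func main() {
	payloads := []string{
		`{"response_format": "b64_json"}`,
		`{"response_format": {"type": "json_object"}}`,
	}
	for _, p := range payloads {
		var req request
		if err := json.Unmarshal([]byte(p), &req); err != nil {
			panic(err)
		}
		switch v := req.ResponseFormat.(type) {
		case string:
			// Image-style request: the string ("url", "b64_json") is used as-is.
			fmt.Println("string form:", v)
		case map[string]interface{}:
			// Chat/completion-style request: round-trip through JSON to
			// recover a typed struct and read its "type" field.
			raw, _ := json.Marshal(v)
			var rf chatResponseFormat
			_ = json.Unmarshal(raw, &rf)
			fmt.Println("object form, type:", rf.Type)
		}
	}
}
```

Re-marshalling the map and unmarshalling it into the typed struct, as the chat.go and completion.go hunks do, keeps `ChatCompletionResponseFormat` as the single definition of the object form instead of reading map keys by hand.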