diff --git a/tools/rest/multi_binder.go b/tools/rest/multi_binder.go index 5467909a..cd8f52ca 100644 --- a/tools/rest/multi_binder.go +++ b/tools/rest/multi_binder.go @@ -5,9 +5,11 @@ import ( "encoding/json" "io" "net/http" + "reflect" "strings" "github.com/labstack/echo/v5" + "github.com/spf13/cast" ) // BindBody binds request body content to i. @@ -28,10 +30,12 @@ func BindBody(c echo.Context, i interface{}) error { return echo.NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error()) } return nil - default: - // fallback to the default binder - return echo.BindBody(c, i) + case strings.HasPrefix(ctype, echo.MIMEApplicationForm), strings.HasPrefix(ctype, echo.MIMEMultipartForm): + return bindFormData(c, i) } + + // fallback to the default binder + return echo.BindBody(c, i) } // CopyJsonBody reads the request body into i by @@ -57,3 +61,69 @@ func CopyJsonBody(r *http.Request, i interface{}) error { return err } + +// This is temp hotfix for properly binding multipart/form-data array values +// when a map destination is used. +// +// It should be replaced with echo.BindBody(c, i) once the issue is fixed in echo. +func bindFormData(c echo.Context, i interface{}) error { + if i == nil { + return nil + } + + values, err := c.FormValues() + if err != nil { + return echo.NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error()) + } + + if len(values) == 0 { + return nil + } + + rt := reflect.TypeOf(i).Elem() + + // map + if rt.Kind() == reflect.Map { + rv := reflect.ValueOf(i).Elem() + + for k, v := range values { + if total := len(v); total == 1 { + rv.SetMapIndex(reflect.ValueOf(k), reflect.ValueOf(normalizeMultipartValue(v[0]))) + } else { + normalized := make([]any, total) + for i, vItem := range v { + normalized[i] = vItem + } + rv.SetMapIndex(reflect.ValueOf(k), reflect.ValueOf(normalized)) + } + } + + return nil + } + + // anything else + return echo.BindBody(c, i) +} + +// In order to support more seamlessly both json and multipart/form-data requests, +// the following normalization rules are applied for plain multipart string values: +// - "true" is converted to the json `true` +// - "false" is converted to the json `false` +// - numeric (non-scientific) strings are converted to json number +// - any other string (empty string too) is left as it is +func normalizeMultipartValue(raw string) any { + switch raw { + case "true": + return true + case "false": + return false + default: + if raw[0] >= '0' && raw[0] <= '9' { + if v, err := cast.ToFloat64E(raw); err == nil { + return v + } + } + + return raw + } +} diff --git a/tools/rest/multi_binder_test.go b/tools/rest/multi_binder_test.go index 853e2191..318e767b 100644 --- a/tools/rest/multi_binder_test.go +++ b/tools/rest/multi_binder_test.go @@ -1,6 +1,7 @@ package rest_test import ( + "encoding/json" "io" "net/http" "net/http/httptest" @@ -16,31 +17,40 @@ func TestBindBody(t *testing.T) { scenarios := []struct { body io.Reader contentType string - result map[string]string + expectBody string expectError bool }{ { strings.NewReader(""), echo.MIMEApplicationJSON, - map[string]string{}, + `{}`, false, }, { strings.NewReader(`{"test":"invalid`), echo.MIMEApplicationJSON, - map[string]string{}, + `{}`, true, }, { - strings.NewReader(`{"test":"test123"}`), + strings.NewReader(`{"test":123}`), echo.MIMEApplicationJSON, - map[string]string{"test": "test123"}, + `{"test":123}`, false, }, { - strings.NewReader(url.Values{"test": []string{"test123"}}.Encode()), + strings.NewReader( + url.Values{ + "string": []string{"str"}, + "stings": []string{"str1", "str2"}, + "number": []string{"123"}, + "numbers": []string{"123", "456"}, + "bool": []string{"true"}, + "bools": []string{"true", "false"}, + }.Encode(), + ), echo.MIMEApplicationForm, - map[string]string{"test": "test123"}, + `{"bool":true,"bools":["true","false"],"number":123,"numbers":["123","456"],"stings":["str1","str2"],"string":"str"}`, false, }, } @@ -52,25 +62,22 @@ func TestBindBody(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) - result := map[string]string{} - err := rest.BindBody(c, &result) + data := map[string]any{} + err := rest.BindBody(c, &data) - if err == nil && scenario.expectError { - t.Errorf("(%d) Expected error, got nil", i) + hasErr := err != nil + if hasErr != scenario.expectError { + t.Errorf("[%d] Expected hasErr %v, got %v", i, scenario.expectError, hasErr) } - if err != nil && !scenario.expectError { - t.Errorf("(%d) Expected nil, got error %v", i, err) + rawBody, err := json.Marshal(data) + if err != nil { + t.Errorf("[%d] Failed to marshal binded body: %v", i, err) + } - if len(result) != len(scenario.result) { - t.Errorf("(%d) Expected %v, got %v", i, scenario.result, result) - } - - for k, v := range result { - if sv, ok := scenario.result[k]; !ok || v != sv { - t.Errorf("(%d) Expected value %v for key %s, got %v", i, sv, k, v) - } + if scenario.expectBody != string(rawBody) { + t.Errorf("[%d] Expected body \n%s, \ngot \n%s", i, scenario.expectBody, rawBody) } } }