Add --allow-all-origins arg

This commit is contained in:
Tom Wright
2022-09-29 16:29:44 +01:00
parent 33f80ef0d8
commit 849a3a2f0a
4 changed files with 25 additions and 8 deletions

View File

@@ -14,11 +14,15 @@ import (
)
// NewHTTPRunner returns a grace runner that runs a HTTP server.
func NewHTTPRunner(generator Generator) grace.Runner {
func NewHTTPRunner(generator Generator, allowAllOrigins bool) grace.Runner {
httpHandler := generateHTTPHandler(generator)
if allowAllOrigins {
httpHandler = allowAllOriginsMiddleware(httpHandler)
}
r := http.NewServeMux()
r.Handle("/generate", http.HandlerFunc(httpHandler))
r.Handle("/generate", httpHandler)
return &gracehttpserverrunner.HTTPServerRunner{
Server: &http.Server{
@@ -29,6 +33,18 @@ func NewHTTPRunner(generator Generator) grace.Runner {
}
}
// allowAllOriginsMiddleware sets appropriate CORS headers to allow requests from any origin.
func allowAllOriginsMiddleware(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
origin := r.Header.Get("Origin")
if origin == "" {
origin = "*"
}
w.Header().Set("Access-Control-Allow-Origin", origin)
h.ServeHTTP(w, r)
})
}
func writeJSON(rw http.ResponseWriter, value interface{}, status int) {
bytes, err := json.Marshal(value)
if err != nil {
@@ -105,8 +121,8 @@ func getDiagramFromPOST(r *http.Request, imgType string) (*Diagram, error) {
const URLParamImageType = "type"
// generateHTTPHandler returns a HTTP handler used to generate a diagram.
func generateHTTPHandler(generator Generator) func(rw http.ResponseWriter, r *http.Request) {
return func(rw http.ResponseWriter, r *http.Request) {
func generateHTTPHandler(generator Generator) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
var diagram *Diagram
imgType := r.URL.Query().Get(URLParamImageType)
@@ -155,5 +171,5 @@ func generateHTTPHandler(generator Generator) func(rw http.ResponseWriter, r *ht
if err := writeImage(rw, diagramBytes, http.StatusOK, imgType); err != nil {
writeErr(rw, fmt.Errorf("could not write diagram: %w", err), http.StatusInternalServerError)
}
}
})
}