From 849a3a2f0aa37bb769e5139301f325e1bc273bf4 Mon Sep 17 00:00:00 2001 From: Tom Wright Date: Thu, 29 Sep 2022 16:29:44 +0100 Subject: [PATCH] Add --allow-all-origins arg --- Dockerfile | 2 +- README.md | 2 +- cmd/app/main.go | 3 ++- internal/http.go | 26 +++++++++++++++++++++----- 4 files changed, 25 insertions(+), 8 deletions(-) diff --git a/Dockerfile b/Dockerfile index 30fa75d..a45167b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -74,5 +74,5 @@ RUN mkdir -p ./out RUN chmod 0777 ./in RUN chmod 0777 ./out -CMD ["./app", "--mermaid=./node_modules/.bin/mmdc", "--in=./in", "--out=./out", "--puppeteer=./puppeteer-config.json"] +CMD ["./app", "--mermaid=./node_modules/.bin/mmdc", "--in=./in", "--out=./out", "--puppeteer=./puppeteer-config.json", "--allow-all-origins=true"] diff --git a/README.md b/README.md index 7cdc3fe..4083585 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ While this currently serves the diagrams via HTTP, it could easily be manipulate Run the container: ``` -docker run -d --name mermaid-server -p 80:80 tomwright/mermaid-server:latest +docker run -d --name mermaid-server -p 80:80 tomwright/mermaid-server:latest --allow-all-origins=true ``` ### Manually as a go command diff --git a/cmd/app/main.go b/cmd/app/main.go index 13a4026..5103e5c 100644 --- a/cmd/app/main.go +++ b/cmd/app/main.go @@ -14,6 +14,7 @@ func main() { in := flag.String("in", "", "Directory to store input files.") out := flag.String("out", "", "Directory to store output files.") puppeteer := flag.String("puppeteer", "", "Full path to optional puppeteer config.") + allowAllOrigins := flag.Bool("allow-all-origins", false, "True to allow all request origins") flag.Parse() if *mermaid == "" { @@ -36,7 +37,7 @@ func main() { cache := internal.NewDiagramCache() generator := internal.NewGenerator(cache, *mermaid, *in, *out, *puppeteer) - httpRunner := internal.NewHTTPRunner(generator) + httpRunner := internal.NewHTTPRunner(generator, *allowAllOrigins) cleanupRunner := internal.NewCleanupRunner(generator) g.Run(httpRunner) diff --git a/internal/http.go b/internal/http.go index 46073c4..efeca61 100644 --- a/internal/http.go +++ b/internal/http.go @@ -14,11 +14,15 @@ import ( ) // NewHTTPRunner returns a grace runner that runs a HTTP server. -func NewHTTPRunner(generator Generator) grace.Runner { +func NewHTTPRunner(generator Generator, allowAllOrigins bool) grace.Runner { httpHandler := generateHTTPHandler(generator) + if allowAllOrigins { + httpHandler = allowAllOriginsMiddleware(httpHandler) + } + r := http.NewServeMux() - r.Handle("/generate", http.HandlerFunc(httpHandler)) + r.Handle("/generate", httpHandler) return &gracehttpserverrunner.HTTPServerRunner{ Server: &http.Server{ @@ -29,6 +33,18 @@ func NewHTTPRunner(generator Generator) grace.Runner { } } +// allowAllOriginsMiddleware sets appropriate CORS headers to allow requests from any origin. +func allowAllOriginsMiddleware(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + origin := r.Header.Get("Origin") + if origin == "" { + origin = "*" + } + w.Header().Set("Access-Control-Allow-Origin", origin) + h.ServeHTTP(w, r) + }) +} + func writeJSON(rw http.ResponseWriter, value interface{}, status int) { bytes, err := json.Marshal(value) if err != nil { @@ -105,8 +121,8 @@ func getDiagramFromPOST(r *http.Request, imgType string) (*Diagram, error) { const URLParamImageType = "type" // generateHTTPHandler returns a HTTP handler used to generate a diagram. -func generateHTTPHandler(generator Generator) func(rw http.ResponseWriter, r *http.Request) { - return func(rw http.ResponseWriter, r *http.Request) { +func generateHTTPHandler(generator Generator) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { var diagram *Diagram imgType := r.URL.Query().Get(URLParamImageType) @@ -155,5 +171,5 @@ func generateHTTPHandler(generator Generator) func(rw http.ResponseWriter, r *ht if err := writeImage(rw, diagramBytes, http.StatusOK, imgType); err != nil { writeErr(rw, fmt.Errorf("could not write diagram: %w", err), http.StatusInternalServerError) } - } + }) }