Merge pull request #99 from TomWright/cors

Add --allow-all-origins arg
This commit is contained in:
Tom Wright 2022-09-29 16:36:41 +01:00 committed by GitHub
commit f37d02c6a7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 25 additions and 8 deletions

View File

@ -74,5 +74,5 @@ RUN mkdir -p ./out
RUN chmod 0777 ./in RUN chmod 0777 ./in
RUN chmod 0777 ./out RUN chmod 0777 ./out
CMD ["./app", "--mermaid=./node_modules/.bin/mmdc", "--in=./in", "--out=./out", "--puppeteer=./puppeteer-config.json"] CMD ["./app", "--mermaid=./node_modules/.bin/mmdc", "--in=./in", "--out=./out", "--puppeteer=./puppeteer-config.json", "--allow-all-origins=true"]

View File

@ -10,7 +10,7 @@ While this currently serves the diagrams via HTTP, it could easily be manipulate
Run the container: Run the container:
``` ```
docker run -d --name mermaid-server -p 80:80 tomwright/mermaid-server:latest docker run -d --name mermaid-server -p 80:80 tomwright/mermaid-server:latest --allow-all-origins=true
``` ```
### Manually as a go command ### Manually as a go command

View File

@ -14,6 +14,7 @@ func main() {
in := flag.String("in", "", "Directory to store input files.") in := flag.String("in", "", "Directory to store input files.")
out := flag.String("out", "", "Directory to store output files.") out := flag.String("out", "", "Directory to store output files.")
puppeteer := flag.String("puppeteer", "", "Full path to optional puppeteer config.") puppeteer := flag.String("puppeteer", "", "Full path to optional puppeteer config.")
allowAllOrigins := flag.Bool("allow-all-origins", false, "True to allow all request origins")
flag.Parse() flag.Parse()
if *mermaid == "" { if *mermaid == "" {
@ -36,7 +37,7 @@ func main() {
cache := internal.NewDiagramCache() cache := internal.NewDiagramCache()
generator := internal.NewGenerator(cache, *mermaid, *in, *out, *puppeteer) generator := internal.NewGenerator(cache, *mermaid, *in, *out, *puppeteer)
httpRunner := internal.NewHTTPRunner(generator) httpRunner := internal.NewHTTPRunner(generator, *allowAllOrigins)
cleanupRunner := internal.NewCleanupRunner(generator) cleanupRunner := internal.NewCleanupRunner(generator)
g.Run(httpRunner) g.Run(httpRunner)

View File

@ -14,11 +14,15 @@ import (
) )
// NewHTTPRunner returns a grace runner that runs a HTTP server. // NewHTTPRunner returns a grace runner that runs a HTTP server.
func NewHTTPRunner(generator Generator) grace.Runner { func NewHTTPRunner(generator Generator, allowAllOrigins bool) grace.Runner {
httpHandler := generateHTTPHandler(generator) httpHandler := generateHTTPHandler(generator)
if allowAllOrigins {
httpHandler = allowAllOriginsMiddleware(httpHandler)
}
r := http.NewServeMux() r := http.NewServeMux()
r.Handle("/generate", http.HandlerFunc(httpHandler)) r.Handle("/generate", httpHandler)
return &gracehttpserverrunner.HTTPServerRunner{ return &gracehttpserverrunner.HTTPServerRunner{
Server: &http.Server{ Server: &http.Server{
@ -29,6 +33,18 @@ func NewHTTPRunner(generator Generator) grace.Runner {
} }
} }
// allowAllOriginsMiddleware sets appropriate CORS headers to allow requests from any origin.
func allowAllOriginsMiddleware(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
origin := r.Header.Get("Origin")
if origin == "" {
origin = "*"
}
w.Header().Set("Access-Control-Allow-Origin", origin)
h.ServeHTTP(w, r)
})
}
func writeJSON(rw http.ResponseWriter, value interface{}, status int) { func writeJSON(rw http.ResponseWriter, value interface{}, status int) {
bytes, err := json.Marshal(value) bytes, err := json.Marshal(value)
if err != nil { if err != nil {
@ -105,8 +121,8 @@ func getDiagramFromPOST(r *http.Request, imgType string) (*Diagram, error) {
const URLParamImageType = "type" const URLParamImageType = "type"
// generateHTTPHandler returns a HTTP handler used to generate a diagram. // generateHTTPHandler returns a HTTP handler used to generate a diagram.
func generateHTTPHandler(generator Generator) func(rw http.ResponseWriter, r *http.Request) { func generateHTTPHandler(generator Generator) http.Handler {
return func(rw http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
var diagram *Diagram var diagram *Diagram
imgType := r.URL.Query().Get(URLParamImageType) imgType := r.URL.Query().Get(URLParamImageType)
@ -155,5 +171,5 @@ func generateHTTPHandler(generator Generator) func(rw http.ResponseWriter, r *ht
if err := writeImage(rw, diagramBytes, http.StatusOK, imgType); err != nil { if err := writeImage(rw, diagramBytes, http.StatusOK, imgType); err != nil {
writeErr(rw, fmt.Errorf("could not write diagram: %w", err), http.StatusInternalServerError) writeErr(rw, fmt.Errorf("could not write diagram: %w", err), http.StatusInternalServerError)
} }
} })
} }