1
0
Fork 0
timelinize/tlzapp/server.go
Matthew Holt 3459478438
Let TLZ_ORIGIN override allowed_origins config
Also print allowed origins on server start for debugging.

See #171
2025-11-14 09:33:43 -07:00

275 lines
8.6 KiB
Go

/*
Timelinize
Copyright (c) 2013 Matthew Holt
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published
by the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
// TODO: rename this package to app?
package tlzapp
import (
"fmt"
"io/fs"
"net"
"net/http"
"net/netip"
"net/url"
"os"
"slices"
"strings"
"time"
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
"go.uber.org/zap"
)
// server bundles the state behind the application's HTTP endpoint: the
// listener and http.Server, the request router, the static file handler,
// and the allow-list of origins that guards the unauthenticated listener.
type server struct {
	app *App // NOTE(review): presumably the application this server belongs to — confirm against the App definition
	log *zap.Logger // used by ServeHTTP to write one log entry per request
	adminLn net.Listener // plaintext, no authentication (loopback-only by default)
	httpServer *http.Server
	// enforce CORS and prevent DNS rebinding for the unauthenticated admin listener
	// (populated by fillAllowedOrigins; consulted by enforceHost and originAllowed)
	allowedOrigins []*url.URL
	mux *http.ServeMux // ServeHTTP hands every request to this router
	staticFiles http.Handler
	frontend fs.FS // the file system that static assets are served from
}
// ServeHTTP is the outermost handler for every request. It performs no real
// handling itself: it stamps the Server header, wraps the response writer in a
// recorder so status and size are known afterward, delegates to the mux, and
// finally writes a single log entry describing the request.
func (s server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	began := time.Now()
	recorder := caddyhttp.NewResponseRecorder(w, nil, nil) // TODO: what other places do we pull in Caddy? Maybe we can strip this down and inline it or something...

	w.Header().Set("Server", "Timelinize")

	// nothing in this function assigns err; it stays nil and only reserves
	// a slot in the log entry for a handling error
	var err error

	defer func() {
		status := recorder.Status()

		// the log message is intentionally specific to bust log sampling here
		msg := r.Method + " " + r.RequestURI
		fields := []zap.Field{
			// TODO: include the listener address used (probably need to use ConnContext or something whatever it's called on the http.Server)
			zap.String("method", r.Method),
			zap.String("uri", r.RequestURI),
			zap.Int("status", status),
			zap.Int("size", recorder.Size()),
			zap.Duration("duration", time.Since(began)),
			zap.Error(err),
		}

		// failed requests are logged at error level, everything else at info
		if err != nil || status >= lowestErrorStatus {
			s.log.Error(msg, fields...)
			return
		}
		s.log.Info(msg, fields...)
	}()

	// setup is complete; let the router do the actual work
	s.mux.ServeHTTP(recorder, r)
}
// fillAllowedOrigins populates s.allowedOrigins, the list that request Host
// and Origin values are compared against.
//
// configuredOrigins is the list from the config; it is replaced entirely by
// the comma-separated TLZ_ORIGIN environment variable when that is non-empty
// (see #171). Unless listenAddr is a unix socket path (leading "/"), the
// loopback hosts localhost, ::1 and 127.0.0.1 are added using the listener's
// port, plus listenAddr itself when it names a non-loopback host.
//
// Entries containing "://" are parsed as URLs and stripped of path, query and
// fragment; entries that fail to parse are dropped. All other entries are used
// verbatim as the URL host. The resulting list is sorted by its source string
// so that its order is stable from run to run.
func (s *server) fillAllowedOrigins(configuredOrigins []string, listenAddr string) {
	listenHost, listenPort, err := net.SplitHostPort(listenAddr)
	if err != nil {
		listenHost, listenPort = listenAddr, "" // assume no port (or a default port)
	}

	// the environment variable takes priority over the config file (see #171)
	if originEnv := os.Getenv("TLZ_ORIGIN"); originEnv != "" {
		origins := strings.Split(originEnv, ",")
		configuredOrigins = make([]string, 0, len(origins))
		for _, origin := range origins {
			if trimmed := strings.TrimSpace(origin); trimmed != "" {
				configuredOrigins = append(configuredOrigins, trimmed)
			}
		}
	}

	// hostWithListenPort renders host as it would appear in a Host header for
	// this listener. With no port, net.JoinHostPort would produce "host:",
	// which can never equal a real Host value, so the bare host is used.
	hostWithListenPort := func(host string) string {
		if listenPort != "" {
			return net.JoinHostPort(host, listenPort)
		}
		if strings.Contains(host, ":") {
			return "[" + host + "]" // IPv6 literals are bracketed in Host headers
		}
		return host
	}

	// a set, so that duplicates between config and defaults collapse
	uniqueOrigins := make(map[string]struct{}, len(configuredOrigins)+4)
	for _, o := range configuredOrigins {
		uniqueOrigins[o] = struct{}{}
	}
	if !strings.HasPrefix(listenAddr, "/") {
		uniqueOrigins[hostWithListenPort("localhost")] = struct{}{}
		uniqueOrigins[hostWithListenPort("::1")] = struct{}{}
		uniqueOrigins[hostWithListenPort("127.0.0.1")] = struct{}{}
		if listenHost != "" && !s.isLoopback(listenAddr) {
			uniqueOrigins[listenAddr] = struct{}{}
		}
	}

	// map iteration order is random; sort so the final list is deterministic
	originStrs := make([]string, 0, len(uniqueOrigins))
	for originStr := range uniqueOrigins {
		originStrs = append(originStrs, originStr)
	}
	slices.Sort(originStrs)

	s.allowedOrigins = make([]*url.URL, 0, len(originStrs))
	for _, originStr := range originStrs {
		if !strings.Contains(originStr, "://") {
			// no scheme: treat the whole string as host[:port]
			s.allowedOrigins = append(s.allowedOrigins, &url.URL{Host: originStr})
			continue
		}
		origin, parseErr := url.Parse(originStr)
		if parseErr != nil {
			continue // unparseable entries are dropped
		}
		// only scheme and authority are relevant to an origin
		origin.Path = ""
		origin.RawPath = ""
		origin.Fragment = ""
		origin.RawFragment = ""
		origin.RawQuery = ""
		s.allowedOrigins = append(s.allowedOrigins, origin)
	}
}
// isLoopback reports whether addr refers to the local machine only: a unix
// socket path (leading "/"), the name "localhost", or an IP address in a
// loopback range. addr may carry a port; when it cannot be split into host
// and port, the whole string is treated as the host.
func (server) isLoopback(addr string) bool {
	// unix socket
	if strings.HasPrefix(addr, "/") {
		return true
	}

	hostname := addr // assume no port unless the split succeeds
	if h, _, splitErr := net.SplitHostPort(addr); splitErr == nil {
		hostname = h
	}

	if hostname == "localhost" {
		return true
	}

	// anything that is not an IP literal is considered non-loopback
	ip, parseErr := netip.ParseAddr(hostname)
	return parseErr == nil && ip.IsLoopback()
}
// enforceHost wraps next so that it only runs when the request's Host header
// exactly equals the host of one of the server's allowed origins. Any other
// Host value is answered with a 403 error. Rejecting unexpected Host values
// helps to mitigate DNS rebinding attacks.
func (s server) enforceHost(next handler) handler {
	return handlerFunc(func(w http.ResponseWriter, r *http.Request) error {
		for _, trusted := range s.allowedOrigins {
			if trusted.Host == r.Host {
				// recognized host; carry on with the wrapped handler
				return next.ServeHTTP(w, r)
			}
		}

		// fell through the whole list without a match
		return Error{
			Err:        fmt.Errorf("unrecognized Host header value '%s'", r.Host),
			HTTPStatus: http.StatusForbidden,
			Log:        "Host not allowed",
			Message:    "This endpoint can only be accessed via a trusted host.",
		}
	})
}
// enforceOriginAndMethod ensures that the Origin header matches the expected value(s),
// sets CORS headers, and also enforces the proper/expected method for the route.
// This prevents arbitrary sites from issuing requests to our listener.
//
// Behavior of the returned handler:
//   - If the request carries an origin (see getOrigin) that is not allowed, a
//     403 error is returned. Requests without any origin skip this check.
//   - For an allowed origin, Access-Control-Allow-Origin is set; preflight
//     (OPTIONS) requests additionally get the Allow-Methods, Allow-Headers and
//     Allow-Credentials headers.
//   - OPTIONS requests end here with a nil error; next is not invoked.
//   - Otherwise the request method must equal method, except that HEAD is
//     also accepted when method is GET; anything else yields a 405 error.
func (s server) enforceOriginAndMethod(method string, next handler) handler {
	return handlerFunc(func(w http.ResponseWriter, r *http.Request) error {
		if origin := s.getOrigin(r); origin != nil {
			if !s.originAllowed(origin) {
				return Error{
					Err:        fmt.Errorf("unrecognized origin '%s'", origin),
					HTTPStatus: http.StatusForbidden,
					Log:        "Origin not allowed",
					Message:    "You can only access this API from a recognized origin.",
				}
			}

			hdr := w.Header()
			hdr.Set("Access-Control-Allow-Origin", origin.String())
			if r.Method == http.MethodOptions {
				hdr.Set("Access-Control-Allow-Methods", "OPTIONS, "+method)
				hdr.Set("Access-Control-Allow-Headers", "Content-Type, Content-Length")
				hdr.Set("Access-Control-Allow-Credentials", "true")
			}
		}

		// a preflight is fully answered by the headers above
		if r.Method == http.MethodOptions {
			return nil
		}

		// method must match, unless GET is expected, in which case HEAD is also allowed
		headForGet := method == http.MethodGet && r.Method == http.MethodHead
		if r.Method != method && !headForGet {
			// "sir, this is an Arby's"
			return Error{
				Err:        fmt.Errorf("method '%s' not allowed", r.Method),
				HTTPStatus: http.StatusMethodNotAllowed,
			}
		}

		return next.ServeHTTP(w, r)
	})
}
// getOrigin extracts the requesting origin from r as a URL with its path,
// query, and fragment removed. The Origin header is preferred; when it is
// empty the Referer header is used instead. It returns nil when neither
// header is present or when the value cannot be parsed as a URL.
func (s server) getOrigin(r *http.Request) *url.URL {
	raw := r.Header.Get("Origin")
	if raw == "" {
		// Firefox has a bug (Jan. 2022) where the Origin header is not sent in
		// HTTPS-Only mode, which breaks CORS requests:
		// https://bugzilla.mozilla.org/show_bug.cgi?id=1751105
		// (confirmed on Twitter by Eric Lawrence, one of the Edge engineers:
		// https://twitter.com/ericlaw/status/1483972139241857027)
		// The workarounds are to disable HTTPS-Only mode in Firefox, to add an
		// exception for "http://localhost", or -- as done here -- to fall back
		// on the Referer header.
		raw = r.Header.Get("Referer")
		if raw == "" {
			return nil
		}
	}

	parsed, err := url.Parse(raw)
	if err != nil {
		return nil
	}

	// keep only the parts that identify an origin
	parsed.Path, parsed.RawPath = "", ""
	parsed.Fragment, parsed.RawFragment = "", ""
	parsed.RawQuery = ""

	return parsed
}
// originAllowed reports whether origin matches any entry of s.allowedOrigins.
// An entry matches when its Host equals origin's Host exactly and, if the
// entry specifies a scheme, that scheme equals origin's scheme as well.
// Entries without a scheme accept any scheme.
func (s server) originAllowed(origin *url.URL) bool {
	return slices.ContainsFunc(s.allowedOrigins, func(candidate *url.URL) bool {
		schemeOK := candidate.Scheme == "" || candidate.Scheme == origin.Scheme
		return schemeOK && candidate.Host == origin.Host
	})
}
// TODO: not sure if we'll need this with this program...
// // apiClient is an HTTP client to be used for API requests,
// // considering that some requests are blocking and can take
// // many seconds to finish, like setting repo key. It honors
// // the application's IPv6 configuration setting.
// var apiClient = &http.Client{
// Transport: &http.Transport{
// Proxy: http.ProxyFromEnvironment,
// DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
// dialer := &net.Dialer{
// Timeout: 10 * time.Second,
// KeepAlive: 120 * time.Second,
// }
// if !Config.IPv6 && network == "tcp" {
// return dialer.DialContext(ctx, "tcp4", addr)
// }
// return dialer.DialContext(ctx, network, addr)
// },
// MaxIdleConns: 100,
// IdleConnTimeout: 90 * time.Second,
// TLSHandshakeTimeout: 10 * time.Second,
// ExpectContinueTimeout: 1 * time.Second,
// },
// Timeout: 60 * time.Second,
// }