Merge remote-tracking branch 'remotes/upstream/master' into patch-listen-tls

This commit is contained in:
Sergey Krashevich
2023-05-21 23:29:08 +03:00
135 changed files with 5094 additions and 1568 deletions
+205
View File
@@ -0,0 +1,205 @@
package api
import (
"crypto/tls"
"encoding/json"
"github.com/AlexxIT/go2rtc/internal/app"
"github.com/rs/zerolog"
"net"
"net/http"
"os"
"strconv"
"strings"
"sync"
)
func Init() {
var cfg struct {
Mod struct {
Listen string `yaml:"listen"`
Username string `yaml:"username"`
Password string `yaml:"password"`
BasePath string `yaml:"base_path"`
StaticDir string `yaml:"static_dir"`
Origin string `yaml:"origin"`
TLSListen string `yaml:"tls_listen"`
TLSCert string `yaml:"tls_cert"`
TLSPrivateKey string `yaml:"tls_private_key"`
} `yaml:"api"`
}
// default config
cfg.Mod.Listen = ":1984"
// load config from YAML
app.LoadConfig(&cfg)
if cfg.Mod.Listen == "" {
return
}
basePath = cfg.Mod.BasePath
log = app.GetLogger("api")
initStatic(cfg.Mod.StaticDir)
initWS(cfg.Mod.Origin)
HandleFunc("api", apiHandler)
HandleFunc("api/config", configHandler)
HandleFunc("api/exit", exitHandler)
HandleFunc("api/ws", apiWS)
// ensure we can listen without errors
listener, err := net.Listen("tcp", cfg.Mod.Listen)
if err != nil {
log.Fatal().Err(err).Msg("[api] listen")
return
}
log.Info().Str("addr", cfg.Mod.Listen).Msg("[api] listen")
Handler = http.DefaultServeMux // 4th
if cfg.Mod.Origin == "*" {
Handler = middlewareCORS(Handler) // 3rd
}
if cfg.Mod.Username != "" {
Handler = middlewareAuth(cfg.Mod.Username, cfg.Mod.Password, Handler) // 2nd
}
if log.Trace().Enabled() {
Handler = middlewareLog(Handler) // 1st
}
go func() {
s := http.Server{}
s.Handler = Handler
if err = s.Serve(listener); err != nil {
log.Fatal().Err(err).Msg("[api] serve")
}
}()
// Initialize the HTTPS server
if cfg.Mod.TLSListen != "" {
tlsConfig := &tls.Config{}
if cfg.Mod.TLSCert != "" && cfg.Mod.TLSPrivateKey != "" {
tlsListener, err := net.Listen("tcp", cfg.Mod.TLSListen)
if err != nil {
log.Fatal().Err(err).Msg("[api] tls listen")
return
}
log.Info().Str("addr", cfg.Mod.TLSListen).Msg("[api] tls listen")
cert, err := tls.X509KeyPair([]byte(cfg.Mod.TLSCert), []byte(cfg.Mod.TLSPrivateKey))
if err != nil {
print(cfg.Mod.TLSCert)
log.Fatal().Err(err).Msg("[api] tls load cert/key")
return
}
tlsConfig.Certificates = []tls.Certificate{cert}
tlsServer := &http.Server{
Handler: Handler,
TLSConfig: tlsConfig,
}
go func() {
if err := tlsServer.ServeTLS(tlsListener, "", ""); err != nil {
log.Fatal().Err(err).Msg("[api] tls serve")
}
}()
}
}
}
var Handler http.Handler
// HandleFunc handle pattern with relative path:
// - "api/streams" => "{basepath}/api/streams"
// - "/streams" => "/streams"
func HandleFunc(pattern string, handler http.HandlerFunc) {
if len(pattern) == 0 || pattern[0] != '/' {
pattern = basePath + "/" + pattern
}
log.Trace().Str("path", pattern).Msg("[api] register path")
http.HandleFunc(pattern, handler)
}
const StreamNotFound = "stream not found"
var basePath string
var log zerolog.Logger
func middlewareLog(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.Trace().Msgf("[api] %s %s %s", r.Method, r.URL, r.RemoteAddr)
next.ServeHTTP(w, r)
})
}
func middlewareAuth(username, password string, next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !strings.HasPrefix(r.RemoteAddr, "127.") && !strings.HasPrefix(r.RemoteAddr, "[::1]") {
user, pass, ok := r.BasicAuth()
if !ok || user != username || pass != password {
w.Header().Set("Www-Authenticate", `Basic realm="go2rtc"`)
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
}
next.ServeHTTP(w, r)
})
}
func middlewareCORS(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Authorization")
next.ServeHTTP(w, r)
})
}
var mu sync.Mutex
func apiHandler(w http.ResponseWriter, r *http.Request) {
mu.Lock()
app.Info["host"] = r.Host
mu.Unlock()
if err := json.NewEncoder(w).Encode(app.Info); err != nil {
log.Warn().Err(err).Caller().Send()
}
}
func exitHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
http.Error(w, "", http.StatusBadRequest)
return
}
s := r.URL.Query().Get("code")
code, _ := strconv.Atoi(s)
os.Exit(code)
}
type Stream struct {
Name string `json:"name"`
URL string `json:"url"`
}
func ResponseStreams(w http.ResponseWriter, streams []Stream) {
if len(streams) == 0 {
http.Error(w, "no streams", http.StatusNotFound)
return
}
var response struct {
Streams []Stream `json:"streams"`
}
response.Streams = streams
if err := json.NewEncoder(w).Encode(response); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}
+102
View File
@@ -0,0 +1,102 @@
package api
import (
"github.com/AlexxIT/go2rtc/internal/app"
"gopkg.in/yaml.v3"
"io"
"net/http"
"os"
)
func configHandler(w http.ResponseWriter, r *http.Request) {
if app.ConfigPath == "" {
http.Error(w, "", http.StatusGone)
return
}
switch r.Method {
case "GET":
data, err := os.ReadFile(app.ConfigPath)
if err != nil {
http.Error(w, "", http.StatusNotFound)
return
}
if _, err = w.Write(data); err != nil {
log.Warn().Err(err).Caller().Send()
}
case "POST", "PATCH":
data, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if r.Method == "PATCH" {
// no need to validate after merge
data, err = mergeYAML(app.ConfigPath, data)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
} else {
// validate config
var tmp struct{}
if err = yaml.Unmarshal(data, &tmp); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
}
if err = os.WriteFile(app.ConfigPath, data, 0644); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
}
}
func mergeYAML(file1 string, yaml2 []byte) ([]byte, error) {
// Read the contents of the first YAML file
data1, err := os.ReadFile(file1)
if err != nil {
return nil, err
}
// Unmarshal the first YAML file into a map
var config1 map[string]any
if err = yaml.Unmarshal(data1, &config1); err != nil {
return nil, err
}
// Unmarshal the second YAML document into a map
var config2 map[string]any
if err = yaml.Unmarshal(yaml2, &config2); err != nil {
return nil, err
}
// Merge the two maps
config1 = merge(config1, config2)
// Marshal the merged map into YAML
return yaml.Marshal(&config1)
}
func merge(dst, src map[string]any) map[string]any {
for k, v := range src {
if vv, ok := dst[k]; ok {
switch vv := vv.(type) {
case map[string]any:
v := v.(map[string]any)
dst[k] = merge(vv, v)
case []any:
v := v.([]any)
dst[k] = v
default:
dst[k] = v
}
} else {
dst[k] = v
}
}
return dst
}
+26
View File
@@ -0,0 +1,26 @@
package api
import (
"github.com/AlexxIT/go2rtc/www"
"net/http"
)
func initStatic(staticDir string) {
var root http.FileSystem
if staticDir != "" {
log.Info().Str("dir", staticDir).Msg("[api] serve static")
root = http.Dir(staticDir)
} else {
root = http.FS(www.Static)
}
base := len(basePath)
fileServer := http.FileServer(root)
HandleFunc("", func(w http.ResponseWriter, r *http.Request) {
if base > 0 {
r.URL.Path = r.URL.Path[base:]
}
fileServer.ServeHTTP(w, r)
})
}
+185
View File
@@ -0,0 +1,185 @@
package api
import (
"github.com/gorilla/websocket"
"net/http"
"net/url"
"strings"
"sync"
"time"
)
// Message - struct for data exchange in Web API
type Message struct {
Type string `json:"type"`
Value any `json:"value,omitempty"`
}
func (m *Message) String() string {
if s, ok := m.Value.(string); ok {
return s
}
return ""
}
func (m *Message) GetString(key string) string {
if v, ok := m.Value.(map[string]any); ok {
if s, ok := v[key].(string); ok {
return s
}
}
return ""
}
type WSHandler func(tr *Transport, msg *Message) error
func HandleWS(msgType string, handler WSHandler) {
wsHandlers[msgType] = handler
}
var wsHandlers = make(map[string]WSHandler)
func initWS(origin string) {
wsUp = &websocket.Upgrader{
ReadBufferSize: 4096, // for SDP
WriteBufferSize: 512 * 1024, // 512K
}
switch origin {
case "":
// same origin + ignore port
wsUp.CheckOrigin = func(r *http.Request) bool {
origin := r.Header["Origin"]
if len(origin) == 0 {
return true
}
o, err := url.Parse(origin[0])
if err != nil {
return false
}
if o.Host == r.Host {
return true
}
log.Trace().Msgf("[api.ws] origin=%s, host=%s", o.Host, r.Host)
// https://github.com/AlexxIT/go2rtc/issues/118
if i := strings.IndexByte(o.Host, ':'); i > 0 {
return o.Host[:i] == r.Host
}
return false
}
case "*":
// any origin
wsUp.CheckOrigin = func(r *http.Request) bool {
return true
}
}
}
func apiWS(w http.ResponseWriter, r *http.Request) {
ws, err := wsUp.Upgrade(w, r, nil)
if err != nil {
origin := r.Header.Get("Origin")
log.Error().Err(err).Caller().Msgf("host=%s origin=%s", r.Host, origin)
return
}
tr := &Transport{Request: r}
tr.OnWrite(func(msg any) {
_ = ws.SetWriteDeadline(time.Now().Add(time.Second * 5))
if data, ok := msg.([]byte); ok {
_ = ws.WriteMessage(websocket.BinaryMessage, data)
} else {
_ = ws.WriteJSON(msg)
}
})
for {
msg := new(Message)
if err = ws.ReadJSON(msg); err != nil {
if !websocket.IsCloseError(err, websocket.CloseNoStatusReceived) {
log.Trace().Err(err).Caller().Send()
}
_ = ws.Close()
break
}
log.Trace().Str("type", msg.Type).Msg("[api.ws] msg")
if handler := wsHandlers[msg.Type]; handler != nil {
go func() {
if err = handler(tr, msg); err != nil {
tr.Write(&Message{Type: "error", Value: msg.Type + ": " + err.Error()})
}
}()
}
}
tr.Close()
}
var wsUp *websocket.Upgrader
type Transport struct {
Request *http.Request
ctx map[any]any
closed bool
mx sync.Mutex
wrmx sync.Mutex
onChange func()
onWrite func(msg any)
onClose []func()
}
func (t *Transport) OnWrite(f func(msg any)) {
t.mx.Lock()
if t.onChange != nil {
t.onChange()
}
t.onWrite = f
t.mx.Unlock()
}
func (t *Transport) Write(msg any) {
t.wrmx.Lock()
t.onWrite(msg)
t.wrmx.Unlock()
}
func (t *Transport) Close() {
t.mx.Lock()
for _, f := range t.onClose {
f()
}
t.closed = true
t.mx.Unlock()
}
func (t *Transport) OnChange(f func()) {
t.mx.Lock()
t.onChange = f
t.mx.Unlock()
}
func (t *Transport) OnClose(f func()) {
t.mx.Lock()
if t.closed {
f()
} else {
t.onClose = append(t.onClose, f)
}
t.mx.Unlock()
}
// WithContext - run function with Context variable
func (t *Transport) WithContext(f func(ctx map[any]any)) {
t.mx.Lock()
if t.ctx == nil {
t.ctx = map[any]any{}
}
f(t.ctx)
t.mx.Unlock()
}