/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package getty import ( "context" "crypto/tls" "crypto/x509" "fmt" "io/ioutil" "net" "net/http" "strings" "sync" "time" ) import ( gxnet "github.com/dubbogo/gost/net" gxsync "github.com/dubbogo/gost/sync" gxtime "github.com/dubbogo/gost/time" "github.com/gorilla/websocket" perrors "github.com/pkg/errors" uatomic "go.uber.org/atomic" ) var ( errSelfConnect = perrors.New("connect self!") serverFastFailTimeout = time.Second * 1 serverID uatomic.Int32 ) // Server interface type Server interface { EndPoint } // StreamServer is like tcp/websocket/wss server type StreamServer interface { Server // Listener get the network listener Listener() net.Listener } // PacketServer is like udp listen endpoint type PacketServer interface { Server // PacketConn get the network listener PacketConn() net.PacketConn } type server struct { ServerOptions // endpoint ID endPointID EndPointID // net pktListener net.PacketConn streamListener net.Listener lock sync.Mutex // for server endPointType EndPointType server *http.Server // for ws or wss server sync.Once done chan struct{} wg sync.WaitGroup } func (s *server) init(opts ...ServerOption) { for _, opt := range opts { opt(&(s.ServerOptions)) } } func newServer(t EndPointType, opts ...ServerOption) *server { s := &server{ endPointID: serverID.Add(1), endPointType: t, done: make(chan struct{}), } s.init(opts...) return s } // NewTCPServer builds a tcp server. func NewTCPServer(opts ...ServerOption) Server { return newServer(TCP_SERVER, opts...) } // NewUDPEndPoint builds a unconnected udp server. func NewUDPEndPoint(opts ...ServerOption) Server { return newServer(UDP_ENDPOINT, opts...) } // NewWSServer builds a websocket server. func NewWSServer(opts ...ServerOption) Server { return newServer(WS_SERVER, opts...) } // NewWSSServer builds a secure websocket server. func NewWSSServer(opts ...ServerOption) Server { s := newServer(WSS_SERVER, opts...) if s.addr == "" || s.cert == "" || s.privateKey == "" { panic(fmt.Sprintf("@addr:%s, @cert:%s, @privateKey:%s, @caCert:%s", s.addr, s.cert, s.privateKey, s.caCert)) } return s } func (s *server) ID() int32 { return s.endPointID } func (s *server) EndPointType() EndPointType { return s.endPointType } func (s *server) stop() { select { case <-s.done: return default: s.Once.Do(func() { close(s.done) s.lock.Lock() if s.server != nil { ctx, cancel := context.WithTimeout(context.Background(), serverFastFailTimeout) if err := s.server.Shutdown(ctx); err != nil { // if the log output is "shutdown ctx: context deadline exceeded", it means that // there are still some active connections. log.Errorf("server shutdown ctx:%s error:%v", ctx, err) } cancel() } s.server = nil s.lock.Unlock() if s.streamListener != nil { // let the server exit asap when got error from RunEventLoop. s.streamListener.Close() s.streamListener = nil } if s.pktListener != nil { s.pktListener.Close() s.pktListener = nil } }) } } func (s *server) GetTaskPool() gxsync.GenericTaskPool { return s.tPool } func (s *server) IsClosed() bool { select { case <-s.done: return true default: return false } } // net.ipv4.tcp_max_syn_backlog // net.ipv4.tcp_timestamps // net.ipv4.tcp_tw_recycle func (s *server) listenTCP() error { var ( err error streamListener net.Listener ) if len(s.addr) == 0 || !strings.Contains(s.addr, ":") { streamListener, err = gxnet.ListenOnTCPRandomPort(s.addr) if err != nil { return perrors.Wrapf(err, "gxnet.ListenOnTCPRandomPort(addr:%s)", s.addr) } } else { if s.sslEnabled { if sslConfig, buildTlsConfErr := s.tlsConfigBuilder.BuildTlsConfig(); buildTlsConfErr == nil && sslConfig != nil { streamListener, err = tls.Listen("tcp", s.addr, sslConfig) } } else { streamListener, err = net.Listen("tcp", s.addr) } if err != nil { return perrors.Wrapf(err, "net.Listen(tcp, addr:%s)", s.addr) } } s.streamListener = streamListener s.addr = s.streamListener.Addr().String() return nil } func (s *server) listenUDP() error { var ( err error localAddr *net.UDPAddr pktListener *net.UDPConn ) if len(s.addr) == 0 || !strings.Contains(s.addr, ":") { pktListener, err = gxnet.ListenOnUDPRandomPort(s.addr) if err != nil { return perrors.Wrapf(err, "gxnet.ListenOnUDPRandomPort(addr:%s)", s.addr) } } else { localAddr, err = net.ResolveUDPAddr("udp", s.addr) if err != nil { return perrors.Wrapf(err, "net.ResolveUDPAddr(udp, addr:%s)", s.addr) } pktListener, err = net.ListenUDP("udp", localAddr) if err != nil { return perrors.Wrapf(err, "net.ListenUDP((udp, localAddr:%#v)", localAddr) } } s.pktListener = pktListener s.addr = s.pktListener.LocalAddr().String() return nil } // Listen announces on the local network address. func (s *server) listen() error { switch s.endPointType { case TCP_SERVER, WS_SERVER, WSS_SERVER: return perrors.WithStack(s.listenTCP()) case UDP_ENDPOINT: return perrors.WithStack(s.listenUDP()) } return nil } func (s *server) accept(newSession NewSessionCallback) (Session, error) { conn, err := s.streamListener.Accept() if err != nil { return nil, perrors.WithStack(err) } if gxnet.IsSameAddr(conn.RemoteAddr(), conn.LocalAddr()) { log.Warnf("conn.localAddr{%s} == conn.RemoteAddr", conn.LocalAddr().String(), conn.RemoteAddr().String()) return nil, perrors.WithStack(errSelfConnect) } ss := newTCPSession(conn, s) err = newSession(ss) if err != nil { conn.Close() return nil, perrors.WithStack(err) } return ss, nil } func (s *server) runTCPEventLoop(newSession NewSessionCallback) { s.wg.Add(1) go func() { defer s.wg.Done() var ( err error client Session delay time.Duration ) for { if s.IsClosed() { log.Infof("server{%s} stop accepting client connect request.", s.addr) return } if delay != 0 { <-gxtime.After(delay) } client, err = s.accept(newSession) if err != nil { if netErr, ok := perrors.Cause(err).(net.Error); ok && netErr.Temporary() { if delay == 0 { delay = 5 * time.Millisecond } else { delay *= 2 } if max := 1 * time.Second; delay > max { delay = max } continue } log.Warnf("server{%s}.Accept() = err {%+v}", s.addr, perrors.WithStack(err)) continue } delay = 0 client.(*session).run() } }() } func (s *server) runUDPEventLoop(newSession NewSessionCallback) { s.wg.Add(1) go func() { defer s.wg.Done() var ( err error conn *net.UDPConn ss Session ) conn = s.pktListener.(*net.UDPConn) ss = newUDPSession(conn, s) if err = newSession(ss); err != nil { conn.Close() panic(err.Error()) } ss.(*session).run() }() } type wsHandler struct { http.ServeMux server *server newSession NewSessionCallback upgrader websocket.Upgrader } func newWSHandler(server *server, newSession NewSessionCallback) *wsHandler { return &wsHandler{ server: server, newSession: newSession, upgrader: websocket.Upgrader{ // in default, ReadBufferSize & WriteBufferSize is 4k // HandshakeTimeout: server.HTTPTimeout, CheckOrigin: func(_ *http.Request) bool { return true }, // allow connections from any origin EnableCompression: true, }, } } func (s *wsHandler) serveWSRequest(w http.ResponseWriter, r *http.Request) { if r.Method != "GET" { // w.WriteHeader(http.StatusMethodNotAllowed) http.Error(w, "Method not allowed", 405) return } if s.server.IsClosed() { http.Error(w, "HTTP server is closed(code:500-11).", 500) log.Warnf("server{%s} stop acceptting client connect request.", s.server.addr) return } conn, err := s.upgrader.Upgrade(w, r, nil) if err != nil { log.Warnf("upgrader.Upgrader(http.Request{%#v}) = error:%+v", r, err) return } if conn.RemoteAddr().String() == conn.LocalAddr().String() { log.Warnf("conn.localAddr{%s} == conn.RemoteAddr", conn.LocalAddr().String(), conn.RemoteAddr().String()) return } // conn.SetReadLimit(int64(handler.maxMsgLen)) ss := newWSSession(conn, s.server) err = s.newSession(ss) if err != nil { conn.Close() log.Warnf("server{%s}.newSession(ss{%#v}) = err {%s}", s.server.addr, ss, err) return } if ss.(*session).maxMsgLen > 0 { conn.SetReadLimit(int64(ss.(*session).maxMsgLen)) } ss.(*session).run() } // runWSEventLoop serve websocket client request // @newSession: new websocket connection callback func (s *server) runWSEventLoop(newSession NewSessionCallback) { s.wg.Add(1) go func() { defer s.wg.Done() var ( err error handler *wsHandler server *http.Server ) handler = newWSHandler(s, newSession) handler.HandleFunc(s.path, handler.serveWSRequest) server = &http.Server{ Addr: s.addr, Handler: handler, // ReadTimeout: server.HTTPTimeout, // WriteTimeout: server.HTTPTimeout, } s.lock.Lock() s.server = server s.lock.Unlock() err = server.Serve(s.streamListener) if err != nil { log.Errorf("http.server.Serve(addr{%s}) = err:%+v", s.addr, perrors.WithStack(err)) } }() } // serve websocket client request // RunWSSEventLoop serve websocket client request func (s *server) runWSSEventLoop(newSession NewSessionCallback) { s.wg.Add(1) go func() { var ( err error certPem []byte certificate tls.Certificate certPool *x509.CertPool config *tls.Config handler *wsHandler server *http.Server ) defer s.wg.Done() if certificate, err = tls.LoadX509KeyPair(s.cert, s.privateKey); err != nil { panic(fmt.Sprintf("tls.LoadX509KeyPair(certs{%s}, privateKey{%s}) = err:%+v", s.cert, s.privateKey, perrors.WithStack(err))) } config = &tls.Config{ InsecureSkipVerify: true, // do not verify peer certs ClientAuth: tls.NoClientCert, NextProtos: []string{"http/1.1"}, Certificates: []tls.Certificate{certificate}, } if s.caCert != "" { certPem, err = ioutil.ReadFile(s.caCert) if err != nil { panic(fmt.Errorf("ioutil.ReadFile(certFile{%s}) = err:%+v", s.caCert, perrors.WithStack(err))) } certPool = x509.NewCertPool() if ok := certPool.AppendCertsFromPEM(certPem); !ok { panic("failed to parse root certificate file") } config.ClientCAs = certPool config.ClientAuth = tls.RequireAndVerifyClientCert config.InsecureSkipVerify = false } handler = newWSHandler(s, newSession) handler.HandleFunc(s.path, handler.serveWSRequest) server = &http.Server{ Addr: s.addr, Handler: handler, // ReadTimeout: server.HTTPTimeout, // WriteTimeout: server.HTTPTimeout, } server.SetKeepAlivesEnabled(true) s.lock.Lock() s.server = server s.lock.Unlock() err = server.Serve(tls.NewListener(s.streamListener, config)) if err != nil { log.Errorf("http.server.Serve(addr{%s}) = err:%+v", s.addr, perrors.WithStack(err)) panic(err) } }() } // RunEventLoop serves client request. // @newSession: new connection callback func (s *server) RunEventLoop(newSession NewSessionCallback) { if err := s.listen(); err != nil { panic(fmt.Errorf("server.listen() = error:%+v", perrors.WithStack(err))) } switch s.endPointType { case TCP_SERVER: s.runTCPEventLoop(newSession) case UDP_ENDPOINT: s.runUDPEventLoop(newSession) case WS_SERVER: s.runWSEventLoop(newSession) case WSS_SERVER: s.runWSSEventLoop(newSession) default: panic(fmt.Sprintf("illegal server type %s", s.endPointType.String())) } } func (s *server) Listener() net.Listener { return s.streamListener } func (s *server) PacketConn() net.PacketConn { return s.pktListener } func (s *server) Close() { s.stop() s.wg.Wait() }