Go 服务端实现
Go 语言以其出色的并发性能成为构建 WebSocket 服务器的优秀选择。本章介绍使用 gorilla/websocket 库实现 WebSocket 服务。
安装
go get github.com/gorilla/websocket
基本服务器
创建一个简单的 WebSocket 服务器:
package main
import (
"log"
"net/http"
"github.com/gorilla/websocket"
)
// upgrader turns an incoming HTTP request into a WebSocket connection,
// using 1024-byte read and write buffers.
var upgrader = websocket.Upgrader{
	WriteBufferSize: 1024,
	ReadBufferSize:  1024,
}
// handleWebSocket upgrades the request and echoes every message back to the
// sender with the same message type, stopping at the first read or write
// failure. The connection is closed when the handler returns.
func handleWebSocket(w http.ResponseWriter, r *http.Request) {
	ws, upgradeErr := upgrader.Upgrade(w, r, nil)
	if upgradeErr != nil {
		log.Println("升级失败:", upgradeErr)
		return
	}
	defer ws.Close()

	log.Println("客户端已连接")
	for {
		kind, payload, readErr := ws.ReadMessage()
		if readErr != nil {
			log.Println("读取消息失败:", readErr)
			return
		}
		log.Printf("收到消息: %s", payload)

		if writeErr := ws.WriteMessage(kind, payload); writeErr != nil {
			log.Println("发送消息失败:", writeErr)
			return
		}
	}
}
// main mounts the echo handler at /ws and serves on port 8080; log.Fatal
// terminates the process if ListenAndServe returns.
func main() {
	const addr = ":8080"
	http.HandleFunc("/ws", handleWebSocket)
	log.Println("WebSocket 服务器运行在 :8080")
	log.Fatal(http.ListenAndServe(addr, nil))
}
Upgrader 配置
Upgrader 用于将 HTTP 连接升级为 WebSocket 连接(下例中的 HandshakeTimeout 使用了 time 包,需额外导入 "time"):
// upgrader accepts handshakes only from a fixed allow-list of origins and
// gives the opening handshake 10 seconds to complete.
var upgrader = websocket.Upgrader{
	ReadBufferSize:   1024,
	WriteBufferSize:  1024,
	HandshakeTimeout: 10 * time.Second,
	CheckOrigin: func(r *http.Request) bool {
		switch r.Header.Get("Origin") {
		case "http://localhost:3000", "https://example.com":
			return true
		default:
			return false
		}
	},
}
CheckOrigin 验证
// upgrader accepts only handshakes whose Origin header is exactly
// http://localhost:3000.
var upgrader = websocket.Upgrader{
	CheckOrigin: func(req *http.Request) bool {
		const allowedOrigin = "http://localhost:3000"
		return req.Header.Get("Origin") == allowedOrigin
	},
}
允许所有来源(仅限开发环境;在生产环境中这样做会取消来源校验,使服务暴露于跨站 WebSocket 劫持攻击):
// upgrader accepts handshakes from any origin. Because no origin is ever
// rejected, this is only suitable for local development.
var upgrader = websocket.Upgrader{
	CheckOrigin: func(*http.Request) bool { return true },
}
读写消息
读取消息
// Read loop: log each incoming frame according to its type until
// ReadMessage reports an error.
for {
	frameType, payload, readErr := conn.ReadMessage()
	if readErr != nil {
		log.Println("读取失败:", readErr)
		break
	}
	switch frameType {
	case websocket.TextMessage:
		log.Printf("文本消息: %s", payload)
	case websocket.BinaryMessage:
		log.Printf("二进制消息: %d 字节", len(payload))
	}
}
写入消息
err := conn.WriteMessage(websocket.TextMessage, []byte("文本消息"))
if err != nil {
log.Println("发送失败:", err)
}
err = conn.WriteMessage(websocket.BinaryMessage, []byte{1, 2, 3, 4})
if err != nil {
log.Println("发送失败:", err)
}
JSON 消息
// Message is the JSON envelope exchanged over the connection, for example
// {"type":"greeting","content":"..."}.
type Message struct {
Type string `json:"type"` // message kind, e.g. "greeting"
Content string `json:"content"` // message payload text
}
var msg Message
err := conn.ReadJSON(&msg)
if err != nil {
log.Println("JSON 解析失败:", err)
return
}
err = conn.WriteJSON(Message{
Type: "greeting",
Content: "欢迎连接",
})
if err != nil {
log.Println("发送 JSON 失败:", err)
}
连接管理
连接管理器
package main
import (
"log"
"net/http"
"sync"
"github.com/gorilla/websocket"
)
// ConnectionManager tracks the set of registered WebSocket connections.
// The map is used as a set: Add only ever stores true.
type ConnectionManager struct {
connections map[*websocket.Conn]bool // set of registered connections
mutex sync.RWMutex // guards connections
}
// NewConnectionManager returns a manager with an empty, ready-to-use
// connection set.
func NewConnectionManager() *ConnectionManager {
	m := new(ConnectionManager)
	m.connections = map[*websocket.Conn]bool{}
	return m
}
// Add registers conn and logs the resulting connection count.
func (m *ConnectionManager) Add(conn *websocket.Conn) {
	total := func() int {
		m.mutex.Lock()
		defer m.mutex.Unlock()
		m.connections[conn] = true
		return len(m.connections)
	}()
	log.Printf("连接已添加,当前连接数: %d", total)
}
// Remove unregisters conn (deleting an absent key is a no-op) and logs the
// remaining connection count.
func (m *ConnectionManager) Remove(conn *websocket.Conn) {
	remaining := func() int {
		m.mutex.Lock()
		defer m.mutex.Unlock()
		delete(m.connections, conn)
		return len(m.connections)
	}()
	log.Printf("连接已移除,当前连接数: %d", remaining)
}
// Broadcast sends message as a text frame to every registered connection.
// Connections whose write fails are closed and removed from the set.
//
// The exclusive lock (not RLock) is required for two reasons:
//   - failed connections are deleted from the map here, and mutating a map
//     while other goroutines hold read locks on it is a data race;
//   - gorilla/websocket permits only one concurrent writer per connection,
//     so writes issued through the manager must be serialized.
//
// NOTE(review): a slow peer blocks every other broadcast while the lock is
// held; consider per-connection write deadlines or send queues.
func (m *ConnectionManager) Broadcast(message []byte) {
	m.mutex.Lock()
	defer m.mutex.Unlock()

	for conn := range m.connections {
		if err := conn.WriteMessage(websocket.TextMessage, message); err != nil {
			log.Println("广播失败:", err)
			conn.Close()
			// Deleting the current key while ranging over a map is allowed in Go.
			delete(m.connections, conn)
		}
	}
}
// BroadcastExcept sends message as a text frame to every registered
// connection except exclude. As in Broadcast, connections whose write fails
// are closed and removed from the set.
//
// The exclusive lock serializes writes. Each connection's handler goroutine
// calls this method, and under RLock two handlers could write to the same
// *websocket.Conn at once, which gorilla/websocket does not allow (one
// concurrent writer per connection). The lock also protects the deletes.
func (m *ConnectionManager) BroadcastExcept(message []byte, exclude *websocket.Conn) {
	m.mutex.Lock()
	defer m.mutex.Unlock()

	for conn := range m.connections {
		if conn == exclude {
			continue
		}
		if err := conn.WriteMessage(websocket.TextMessage, message); err != nil {
			log.Println("广播失败:", err)
			conn.Close()
			delete(m.connections, conn)
		}
	}
}
var (
	// upgrader performs the handshake and accepts requests from any origin
	// (a development-only setting).
	upgrader = websocket.Upgrader{
		CheckOrigin: func(*http.Request) bool { return true },
	}

	// manager is the process-wide connection registry shared by all handlers.
	manager = NewConnectionManager()
)
// handleWebSocket upgrades the request and registers the connection with the
// shared manager for the lifetime of the call. It then relays each incoming
// message to every other registered connection until a read fails.
func handleWebSocket(w http.ResponseWriter, r *http.Request) {
	ws, upgradeErr := upgrader.Upgrade(w, r, nil)
	if upgradeErr != nil {
		log.Println("升级失败:", upgradeErr)
		return
	}
	defer ws.Close()

	manager.Add(ws)
	defer manager.Remove(ws)

	for {
		_, payload, readErr := ws.ReadMessage()
		if readErr != nil {
			log.Println("读取失败:", readErr)
			return
		}
		manager.BroadcastExcept(payload, ws)
	}
}
// main exposes the relay endpoint at /ws and blocks serving on :8080.
func main() {
	http.HandleFunc("/ws", handleWebSocket)
	log.Println("服务器运行在 :8080")
	err := http.ListenAndServe(":8080", nil)
	log.Fatal(err)
}
心跳检测
package main
import (
"log"
"net/http"
"time"
"github.com/gorilla/websocket"
)
// upgrader performs the WebSocket handshake and, as a development
// convenience, accepts requests from any origin.
var upgrader = websocket.Upgrader{
	CheckOrigin: func(_ *http.Request) bool {
		return true
	},
}
// Heartbeat timing and inbound size limits.
const (
	// maxMessageSize is the read limit, in bytes, passed to SetReadLimit.
	maxMessageSize = 512

	// writeWait bounds how long a single write may take.
	writeWait = 10 * time.Second

	// pongWait is how long the read deadline is extended after each pong.
	pongWait = 60 * time.Second

	// pingPeriod is 90% of pongWait, so a ping goes out before the read
	// deadline can expire.
	pingPeriod = pongWait * 9 / 10
)
// handleWebSocket upgrades the request and runs an echo loop guarded by a
// ping/pong heartbeat:
//   - inbound messages are limited to maxMessageSize bytes;
//   - the read deadline starts at pongWait and is pushed out by pongWait on
//     every pong, so a silent peer makes ReadMessage fail;
//   - a background goroutine sends a ping every pingPeriod and exits when
//     the handler returns.
func handleWebSocket(w http.ResponseWriter, r *http.Request) {
	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		log.Println("升级失败:", err)
		return
	}
	defer conn.Close()

	conn.SetReadLimit(maxMessageSize)
	conn.SetReadDeadline(time.Now().Add(pongWait))
	conn.SetPongHandler(func(string) error {
		conn.SetReadDeadline(time.Now().Add(pongWait))
		return nil
	})

	// done stops the ping goroutine as soon as the handler returns, instead
	// of leaving it alive until its next ping fails on the closed connection.
	// This defer runs before the conn.Close deferred above.
	done := make(chan struct{})
	defer close(done)

	go func() {
		ticker := time.NewTicker(pingPeriod)
		defer ticker.Stop()
		for {
			select {
			case <-ticker.C:
				// WriteControl may be called concurrently with the echo loop's
				// WriteMessage. A WriteMessage ping here would be a second
				// concurrent writer, which gorilla/websocket does not allow.
				if err := conn.WriteControl(websocket.PingMessage, nil, time.Now().Add(writeWait)); err != nil {
					log.Println("Ping 失败:", err)
					return
				}
			case <-done:
				return
			}
		}
	}()

	for {
		_, message, err := conn.ReadMessage()
		if err != nil {
			log.Println("读取失败:", err)
			break
		}
		log.Printf("收到消息: %s", message)

		conn.SetWriteDeadline(time.Now().Add(writeWait))
		if err := conn.WriteMessage(websocket.TextMessage, message); err != nil {
			log.Println("发送消息失败:", err)
			break
		}
	}
}
// main registers the heartbeat echo handler at /ws and serves on :8080.
func main() {
	http.Handle("/ws", http.HandlerFunc(handleWebSocket))
	log.Println("服务器运行在 :8080")
	log.Fatal(http.ListenAndServe(":8080", nil))
}
聊天室示例
package main
import (
"encoding/json"
"log"
"net/http"
"sync"
"time"
"github.com/gorilla/websocket"
)
// Message is the JSON frame delivered to chat clients. Type is "system" for
// join/leave notices and "message" for user chat. Username and Content are
// omitted from the JSON when empty. Timestamp holds RFC 3339 text.
type Message struct {
Type string `json:"type"` // "system" or "message"
Username string `json:"username,omitempty"` // sender's display name (empty for system notices)
Content string `json:"content,omitempty"` // message text
Timestamp string `json:"timestamp"` // time.Now() formatted with time.RFC3339
}
// Client is one chat participant bound to a single room.
type Client struct {
conn *websocket.Conn // underlying WebSocket connection
username string // display name taken from the "username" query parameter
room *Room // room this client joined
send chan []byte // outbound JSON frames, drained by writePump (created with capacity 256 in handleWebSocket)
}
// Room is a named chat room holding its current members and a bounded
// message history.
type Room struct {
name string // room name, as requested via the "room" query parameter
clients map[*Client]bool // set of current members
messages []Message // recent history; broadcast keeps at most 100 entries
mutex sync.RWMutex // guards clients and messages
}
// NewRoom creates a room called name with no members and no history.
func NewRoom(name string) *Room {
	room := &Room{name: name}
	room.clients = map[*Client]bool{}
	return room
}
// join replays the room's stored history into client.send, registers the
// client as a member, and then announces the arrival to all members,
// including the new client.
//
// Replay and registration happen under a single exclusive lock. Previously
// the client was registered first and history was sent under a separate lock,
// so a concurrent broadcast could slip a live message in ahead of (or between)
// the history entries.
func (r *Room) join(client *Client) {
	r.mutex.Lock()
	// These sends do not block as long as client.send has room for the whole
	// history (at most 100 entries). handleWebSocket creates it with capacity 256.
	for _, msg := range r.messages {
		data, _ := json.Marshal(msg) // Message has only string fields; Marshal cannot fail.
		client.send <- data
	}
	r.clients[client] = true
	r.mutex.Unlock()

	r.broadcast(Message{
		Type:      "system",
		Content:   client.username + " 加入了聊天室",
		Timestamp: time.Now().Format(time.RFC3339),
	}, nil)
}
// leave removes client from the room and announces the departure to the
// remaining members.
//
// It also closes client.send, which ends the client's writePump loop.
// Without this close, every disconnected client leaked a goroutine blocked
// forever in `range c.send`. The close is guarded by the membership check
// because broadcast may already have removed a stalled client and closed its
// channel. Closing a channel twice panics.
func (r *Room) leave(client *Client) {
	r.mutex.Lock()
	if _, member := r.clients[client]; member {
		delete(r.clients, client)
		close(client.send)
	}
	r.mutex.Unlock()

	r.broadcast(Message{
		Type:      "system",
		Content:   client.username + " 离开了聊天室",
		Timestamp: time.Now().Format(time.RFC3339),
	}, nil)
}
// broadcast appends message to the room history, keeping the latest 100
// entries, and queues it for every member except exclude (nil excludes
// nobody). Queuing never blocks. A member whose send buffer is full is
// treated as stalled: its channel is closed, which ends its writePump, and it
// is removed from the room.
//
// The whole method runs under the exclusive lock because the loop deletes
// from r.clients and closes channels. The previous version did this under
// RLock, so two concurrent broadcasts could write the map at the same time (a
// data race / fatal "concurrent map writes") or close the same channel twice
// (panic).
func (r *Room) broadcast(message Message, exclude *Client) {
	data, _ := json.Marshal(message) // Message has only string fields; Marshal cannot fail.

	r.mutex.Lock()
	defer r.mutex.Unlock()

	r.messages = append(r.messages, message)
	if len(r.messages) > 100 {
		r.messages = r.messages[1:]
	}

	for client := range r.clients {
		if client == exclude {
			continue
		}
		select {
		case client.send <- data:
		default:
			close(client.send)
			delete(r.clients, client)
		}
	}
}
// ChatServer owns all chat rooms, keyed by room name. Rooms are created on
// first use by getRoom; nothing in the code shown ever removes one.
type ChatServer struct {
rooms map[string]*Room // rooms by name
mutex sync.RWMutex // guards rooms
}
// NewChatServer returns a server that has no rooms yet.
func NewChatServer() *ChatServer {
	srv := &ChatServer{}
	srv.rooms = map[string]*Room{}
	return srv
}
// getRoom returns the room called name, creating and registering it first if
// it does not exist yet. The lookup and creation are atomic under the lock.
func (s *ChatServer) getRoom(name string) *Room {
	s.mutex.Lock()
	defer s.mutex.Unlock()

	room, exists := s.rooms[name]
	if !exists {
		room = NewRoom(name)
		s.rooms[name] = room
	}
	return room
}
var (
	// upgrader accepts handshakes from any origin (development setting).
	upgrader = websocket.Upgrader{
		CheckOrigin: func(*http.Request) bool { return true },
	}

	// chatServer is the single room registry shared by all connections.
	chatServer = NewChatServer()
)
// handleWebSocket upgrades the request and reads the room and username from
// the query string, falling back to "general" and an anonymous placeholder
// name when either is empty. It joins the client to that room, then runs the
// write pump in the background and the read pump on this goroutine until the
// connection ends.
func handleWebSocket(w http.ResponseWriter, r *http.Request) {
	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		log.Println("升级失败:", err)
		return
	}

	orDefault := func(value, fallback string) string {
		if value == "" {
			return fallback
		}
		return value
	}
	query := r.URL.Query()
	room := chatServer.getRoom(orDefault(query.Get("room"), "general"))

	client := &Client{
		conn:     conn,
		username: orDefault(query.Get("username"), "匿名"),
		room:     room,
		send:     make(chan []byte, 256),
	}
	room.join(client)

	go client.writePump()
	client.readPump()
}
// readPump reads frames from the client until the connection fails and
// broadcasts each frame's "content" field to the rest of the room. It skips
// frames that are not a JSON object, and frames whose "content" is missing or
// not a string. On exit the client leaves the room and the connection is
// closed.
func (c *Client) readPump() {
	defer func() {
		c.room.leave(c)
		c.conn.Close()
	}()
	for {
		_, message, err := c.conn.ReadMessage()
		if err != nil {
			break
		}
		var data map[string]interface{}
		if err := json.Unmarshal(message, &data); err != nil {
			continue
		}
		// Checked assertion: the unchecked data["content"].(string) panicked
		// whenever a client omitted "content" or sent a non-string value.
		content, ok := data["content"].(string)
		if !ok {
			continue
		}
		c.room.broadcast(Message{
			Type:      "message",
			Username:  c.username,
			Content:   content,
			Timestamp: time.Now().Format(time.RFC3339),
		}, c)
	}
}
// writePump forwards every queued frame from c.send to the connection as a
// text message. It stops when the channel is closed or a write fails, and it
// closes the connection on exit.
func (c *Client) writePump() {
	defer c.conn.Close()
	for {
		frame, open := <-c.send
		if !open {
			return
		}
		if err := c.conn.WriteMessage(websocket.TextMessage, frame); err != nil {
			return
		}
	}
}
// main serves the chat endpoint at /ws on port 8080 using the default mux.
func main() {
	mux := http.DefaultServeMux
	mux.HandleFunc("/ws", handleWebSocket)
	log.Println("聊天服务器运行在 :8080")
	log.Fatal(http.ListenAndServe(":8080", nil))
}
使用 Gin 框架
package main
import (
"log"
"net/http"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
)
// upgrader performs the handshake for Gin routes and accepts requests from
// any origin (development setting).
var upgrader = websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }}
// handleWebSocket upgrades a Gin request to a WebSocket connection and echoes
// each message back with its original type. The loop ends when a read or a
// write fails, and the connection is closed on return.
func handleWebSocket(c *gin.Context) {
	conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
	if err != nil {
		log.Println("升级失败:", err)
		return
	}
	defer conn.Close()
	for {
		messageType, message, err := conn.ReadMessage()
		if err != nil {
			log.Println("读取失败:", err)
			break
		}
		log.Printf("收到消息: %s", message)
		// Stop on write failure instead of ignoring the error and reading on.
		if err := conn.WriteMessage(messageType, message); err != nil {
			log.Println("发送失败:", err)
			break
		}
	}
}
// main wires the Gin router: /ws serves the WebSocket echo endpoint and /
// renders the template loaded from the relative path "index.html".
func main() {
	r := gin.Default()
	r.GET("/ws", handleWebSocket)
	r.GET("/", func(c *gin.Context) {
		c.HTML(http.StatusOK, "index.html", nil)
	})
	r.LoadHTMLFiles("index.html")
	log.Println("服务器运行在 :8080")
	// Run blocks while serving and returns an error if the listener fails
	// (e.g. port already in use). Previously that error was silently dropped.
	log.Fatal(r.Run(":8080"))
}
生产环境配置
优雅关闭
package main
import (
"context"
"log"
"net/http"
"os"
"os/signal"
"syscall"
"time"
"github.com/gorilla/websocket"
)
// upgrader performs the WebSocket handshake and accepts any origin
// (development setting).
var upgrader = websocket.Upgrader{
	CheckOrigin: func(*http.Request) bool {
		return true
	},
}
// handleWebSocket upgrades the request and echoes each message back with its
// original type (text or binary) until a read or write fails. The connection
// is closed on return.
func handleWebSocket(w http.ResponseWriter, r *http.Request) {
	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		// Upgrade has already replied to the client with an HTTP error.
		log.Println("升级失败:", err)
		return
	}
	defer conn.Close()
	for {
		messageType, message, err := conn.ReadMessage()
		if err != nil {
			break
		}
		// Echo with the received type. Forcing TextMessage turned binary
		// payloads into text frames, which are required to be valid UTF-8.
		if err := conn.WriteMessage(messageType, message); err != nil {
			break
		}
	}
}
// main runs the echo server on :8080 and shuts it down when SIGINT or SIGTERM
// arrives, giving the HTTP server up to 10 seconds to finish.
//
// Note: http.Server.Shutdown neither closes nor waits for hijacked
// connections, and every upgraded WebSocket is hijacked. Open WebSocket
// sessions are therefore NOT drained here. To close them gracefully, track
// them and notify them from a hook registered with server.RegisterOnShutdown.
func main() {
	server := &http.Server{
		Addr:    ":8080",
		Handler: http.HandlerFunc(handleWebSocket),
		// Bound the time a client may take to send request headers, which
		// protects against slow-header (Slowloris) attacks. It applies only
		// while headers are read, before the WebSocket upgrade.
		ReadHeaderTimeout: 10 * time.Second,
	}

	go func() {
		log.Println("服务器运行在 :8080")
		if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
			log.Fatal("服务器错误:", err)
		}
	}()

	quit := make(chan os.Signal, 1)
	signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
	<-quit
	log.Println("正在关闭服务器...")

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	if err := server.Shutdown(ctx); err != nil {
		log.Fatal("服务器关闭错误:", err)
	}
	log.Println("服务器已关闭")
}
Nginx 配置
# Backend pool: the Go WebSocket server listening on 127.0.0.1:8080.
upstream websocket {
server 127.0.0.1:8080;
}
server {
listen 80;
server_name example.com;
# Proxy the WebSocket endpoint to the Go backend.
location /ws {
proxy_pass http://websocket;
# The protocol upgrade handshake requires HTTP/1.1 to the upstream.
proxy_http_version 1.1;
# Upgrade and Connection are hop-by-hop headers that are not passed to the
# upstream by default; forward them explicitly so the handshake reaches Go.
# NOTE(review): this sends "Connection: upgrade" even on plain requests;
# nginx docs show a `map $http_upgrade $connection_upgrade` variant.
proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection "upgrade";
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
# Idle read timeout of 86400 seconds (24 h). The nginx default of 60 s would
# drop connections that stay silent longer than that.
proxy_read_timeout 86400;
}
}
小结
本章介绍了使用 Go 语言实现 WebSocket 服务器:
- gorilla/websocket:Go 中最流行的 WebSocket 库
- Upgrader 配置:连接升级、Origin 验证
- 消息读写:文本、二进制、JSON 消息
- 连接管理:线程安全的连接管理器
- 心跳检测:Ping/Pong 机制
- Gin 集成:与 Gin 框架配合使用
- 生产环境:优雅关闭与 Nginx 反向代理配置
Go 语言的高并发特性使其非常适合构建高性能 WebSocket 服务器。