Commit a9dafd10 by qlintonger xeno

新旧断线处理完毕

parent 175f2fc3
use crate::handle_messages::handle_other_message; use crate::handle_messages::handle_other_message;
use crate::heartbeat::handle_heartbeat; use crate::heartbeat::handle_heartbeat;
use crate::json_utils::{make_common_resp, parse_message}; use crate::json_utils::{make_common_resp, parse_message};
use crate::utils; use crate::{config, utils};
use futures::{SinkExt, StreamExt}; use futures::{SinkExt, StreamExt};
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::{Arc}; use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use tokio::task::JoinHandle;
use tokio::time; use tokio::time;
use tokio_tungstenite::{accept_hdr_async, WebSocketStream}; use tokio_tungstenite::{accept_hdr_async, WebSocketStream};
use tungstenite::handshake::server::{Request, Response}; use tungstenite::handshake::server::{Request, Response};
use tungstenite::{Error, Message}; use tungstenite::{Error, Message};
use tokio::sync::Mutex; use tokio::sync::Mutex as AsyncMutex;
use lazy_static::lazy_static; use lazy_static::lazy_static;
use redis::{Client}; use redis::{Client};
use redis_pool::{ SingleRedisPool}; use redis_pool::{SingleRedisPool};
use redis::Commands;
use redis::Commands; // 新增导入 use config::STATIC_WS_PWD;
lazy_static! { lazy_static! {
static ref REDIS_POOL: SingleRedisPool = { static ref REDIS_POOL: SingleRedisPool = {
...@@ -25,21 +26,37 @@ lazy_static! { ...@@ -25,21 +26,37 @@ lazy_static! {
} }
// 自定义结构体来存储发送器和接收器 // 自定义结构体来存储发送器和接收器
#[derive(Debug)]
struct Connection { struct Connection {
sender: futures::stream::SplitSink<WebSocketStream<tokio::net::TcpStream>, Message>, sender: futures::stream::SplitSink<WebSocketStream<tokio::net::TcpStream>, Message>,
receiver: futures::stream::SplitStream<WebSocketStream<tokio::net::TcpStream>>, receiver: futures::stream::SplitStream<WebSocketStream<tokio::net::TcpStream>>,
} }
// 全局连接映射,存储 fromId 到 Connection 的映射 // 全局连接映射,存储 fromId 到 Connection 的映射
type ConnectionMap = Arc<Mutex<HashMap<String, Connection>>>; type ConnectionMap = Arc<AsyncMutex<HashMap<String, Connection>>>;
// 全局任务映射,存储 fromId 到 JoinHandle 的映射
type TaskMap = Arc<Mutex<HashMap<String, JoinHandle<()>>>>;
lazy_static! { lazy_static! {
static ref CONNECTIONS: ConnectionMap = Arc::new(Mutex::new(HashMap::new())); static ref CONNECTIONS: ConnectionMap = Arc::new(AsyncMutex::new(HashMap::new()));
static ref TASKS: TaskMap = Arc::new(Mutex::new(HashMap::new()));
} }
// 关闭之前绑定的 WebSocket 连接 // 关闭之前绑定的 WebSocket 连接并取消对应的任务
async fn close_existing_connection(from_id: &str) { async fn close_existing_connection(from_id: &str) {
let mut connections = CONNECTIONS.lock(); let task_to_abort = {
if let Some(mut old_connection) = connections.await.remove(from_id) { let mut tasks = TASKS.lock().unwrap();
tasks.remove(from_id)
};
if let Some(task) = task_to_abort {
task.abort();
}
let mut connections = CONNECTIONS.lock().await;
let already_done = connections.get(&from_id.to_string());
println!("关闭之前绑定的 WebSocket 连接: {} {:?}", from_id, already_done);
if let Some(mut old_connection) = connections.remove(from_id) {
// 尝试优雅地关闭旧连接 // 尝试优雅地关闭旧连接
if let Err(e) = old_connection.sender.close().await { if let Err(e) = old_connection.sender.close().await {
println!("关闭旧的 WebSocket 发送器时出错: {}", e); println!("关闭旧的 WebSocket 发送器时出错: {}", e);
...@@ -48,15 +65,21 @@ async fn close_existing_connection(from_id: &str) { ...@@ -48,15 +65,21 @@ async fn close_existing_connection(from_id: &str) {
} }
} }
// 更新 Redis 中 connected 数组 // 更新 Redis 中 connected 集合
async fn update_connected_redis() { async fn update_connected_redis() {
let connections = CONNECTIONS.lock().await; let connections = CONNECTIONS.lock().await;
let from_ids: Vec<String> = connections.keys().cloned().collect(); let from_ids: Vec<String> = connections.keys().cloned().collect();
let mut con = REDIS_POOL.get_connection().expect("Failed to get Redis connection"); let mut con = REDIS_POOL.get_connection().expect("Failed to get Redis connection");
// 将每个 fromId 依次添加到列表中
// 先清空集合
if let Err(e) = con.del::<_, ()>("connected") {
println!("Failed to delete connected key in Redis: {}", e);
}
// 将每个 fromId 依次添加到集合中
for from_id in from_ids { for from_id in from_ids {
if let Err(e) = con.rpush::<_, _, ()>("connected", from_id) { if let Err(e) = con.sadd::<_, _, ()>("connected", from_id) {
println!("Failed to add fromId to connected list in Redis: {}", e); println!("Failed to add fromId to connected set in Redis: {}", e);
} }
} }
} }
...@@ -86,6 +109,11 @@ fn handle_handshake( ...@@ -86,6 +109,11 @@ fn handle_handshake(
return Err(error_msg); return Err(error_msg);
} }
if connection_params.get("wsPwd").unwrap() != (STATIC_WS_PWD) {
println!("wsPwd不正确!");
return Err("wsPwd不正确!".to_string());
}
Ok(connection_params) Ok(connection_params)
} }
...@@ -100,7 +128,6 @@ pub(crate) async fn handle_client(stream: tokio::net::TcpStream) -> Result<(), E ...@@ -100,7 +128,6 @@ pub(crate) async fn handle_client(stream: tokio::net::TcpStream) -> Result<(), E
Ok(resp) Ok(resp)
} }
Err(error_msg) => { Err(error_msg) => {
// 可以在这里根据需要进一步处理错误响应,这里简单打印错误信息
println!("{}", error_msg); println!("{}", error_msg);
let error_resp = Response::builder() let error_resp = Response::builder()
.status(400) .status(400)
...@@ -120,39 +147,41 @@ pub(crate) async fn handle_client(stream: tokio::net::TcpStream) -> Result<(), E ...@@ -120,39 +147,41 @@ pub(crate) async fn handle_client(stream: tokio::net::TcpStream) -> Result<(), E
} }
}; };
// 将 WebSocketStream 拆分为发送器和接收器
let (sender, receiver) = ws_stream.split(); let (sender, receiver) = ws_stream.split();
if let Some(params) = connection_params { if let Some(params) = connection_params {
if let Some(from_id) = params.get("fromId") { if let Some(from_id) = params.get("fromId") {
let from_id = from_id.clone(); let from_id = from_id.clone();
let from_id_clone = from_id.clone(); // 新增:克隆一份 from_id 用于闭包
// 从 Redis 连接池获取连接
let con = REDIS_POOL.get_connection().expect("Failed to get Redis connection");
// 检查 Redis 中是否已经存在该 fromId // 检查 Redis 中是否已经存在该 fromId
close_existing_connection(&from_id).await; close_existing_connection(&from_id).await;
// 将新连接添加到全局连接映射 // 将新连接添加到全局连接映射
CONNECTIONS.lock().await.insert(from_id.clone(), Connection { sender, receiver }); {
let mut connections = CONNECTIONS.lock().await;
connections.insert(from_id.clone(), Connection { sender, receiver });
}
// 更新 Redis 中的 connected 数组 // 更新 Redis 中的 connected 集合
update_connected_redis().await; update_connected_redis().await;
let task = tokio::spawn(async move {
let mut last_heartbeat_time = Instant::now(); let mut last_heartbeat_time = Instant::now();
loop { loop {
let mut connections = CONNECTIONS.lock().await; let mut connections = CONNECTIONS.lock().await;
let current_connection = connections.get_mut(&from_id).unwrap(); if let Some(current_connection) = connections.get_mut(&from_id_clone) { // 使用克隆后的 from_id
let receiver = &mut current_connection.receiver; let receiver_ref = &mut current_connection.receiver;
let sender = &mut current_connection.sender; let sender_ref = &mut current_connection.sender;
tokio::select! { tokio::select! {
// 处理消息接收 // 处理消息接收
maybe_msg = receiver.next() => { maybe_msg = receiver_ref.next() => {
match maybe_msg { match maybe_msg {
Some(Ok(msg)) => { Some(Ok(msg)) => {
if msg.is_text() { if msg.is_text() {
let text = msg.to_text()?; let text = msg.to_text().unwrap();
match parse_message(text) { match parse_message(text) {
Ok(data) => { Ok(data) => {
match data.msg_type.as_str() { match data.msg_type.as_str() {
...@@ -160,11 +189,11 @@ pub(crate) async fn handle_client(stream: tokio::net::TcpStream) -> Result<(), E ...@@ -160,11 +189,11 @@ pub(crate) async fn handle_client(stream: tokio::net::TcpStream) -> Result<(), E
println!("收到客户端心跳消息 {:?}", &data); println!("收到客户端心跳消息 {:?}", &data);
handle_heartbeat(&mut last_heartbeat_time); handle_heartbeat(&mut last_heartbeat_time);
if let Ok(json_str) = make_common_resp(Default::default(), "Heart") { if let Ok(json_str) = make_common_resp(Default::default(), "Heart") {
sender.send(Message::text(json_str)).await?; sender_ref.send(Message::text(json_str)).await.unwrap();
} }
}, },
_ => { _ => {
handle_other_message(sender, &data).await?; handle_other_message(sender_ref, &data).await.unwrap();
} }
} }
} }
...@@ -186,17 +215,28 @@ pub(crate) async fn handle_client(stream: tokio::net::TcpStream) -> Result<(), E ...@@ -186,17 +215,28 @@ pub(crate) async fn handle_client(stream: tokio::net::TcpStream) -> Result<(), E
} }
// 处理心跳超时 // 处理心跳超时
_ = time::sleep_until(tokio::time::Instant::from(last_heartbeat_time + Duration::from_secs(20))) => { _ = time::sleep_until(tokio::time::Instant::from(last_heartbeat_time + Duration::from_secs(20))) => {
println!("用户id-{} 20秒内没有发送心跳,挂断连接", from_id); println!("用户id-{} 20秒内没有发送心跳,挂断连接", from_id_clone); // 使用克隆后的 from_id
break; break;
} }
} }
} else {
break;
}
} }
println!("断开与用户id: {},连接", from_id); println!("断开与用户id: {},连接", from_id_clone); // 使用克隆后的 from_id
// 从全局连接映射中移除该连接 // 从全局连接映射中移除该连接
CONNECTIONS.lock().await.remove(&from_id); {
// 更新 Redis 中的 connected 数组 let mut connections = CONNECTIONS.lock().await;
connections.remove(&from_id_clone); // 使用克隆后的 from_id
}
// 更新 Redis 中的 connected 集合
update_connected_redis().await; update_connected_redis().await;
});
// 将任务句柄存储到全局任务映射中
let mut tasks = TASKS.lock().unwrap();
tasks.insert(from_id, task); // 使用原始的 from_id
} }
} else { } else {
println!("无法获取连接参数"); println!("无法获取连接参数");
......
pub const STATIC_WS_PWD: &str = "Q8kFm5LzJ2Ab";
\ No newline at end of file
...@@ -5,13 +5,14 @@ mod utils; ...@@ -5,13 +5,14 @@ mod utils;
mod json_utils; mod json_utils;
mod heartbeat; mod heartbeat;
mod handle_messages; mod handle_messages;
mod config;
use client::handle_client; use client::handle_client;
use tokio::net::TcpListener; use tokio::net::TcpListener;
#[tokio::main] #[tokio::main]
async fn main() { async fn main() {
let addr = "0.0.0.0:8080"; let addr = "0.0.0.0:12345";
let listener = TcpListener::bind(addr).await.unwrap(); let listener = TcpListener::bind(addr).await.unwrap();
while let Ok((stream, _)) = listener.accept().await { while let Ok((stream, _)) = listener.accept().await {
tokio::spawn(handle_client(stream)); tokio::spawn(handle_client(stream));
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment