|
|
|
import asyncio |
|
from fastapi import FastAPI, WebSocket |
|
import uvicorn |
|
|
|
|
|
app = FastAPI()

# Status line returned for any request that is not a well-formed CONNECT.
_BAD_REQUEST = "HTTP/1.1 400 Bad Request\r\n\r\n"

# Maximum number of bytes read from the TCP side per outgoing WebSocket frame.
_TCP_READ_SIZE = 64 * 1024


def _parse_connect_target(request_text: str) -> "tuple[str, int]":
    """Return ``(host, port)`` from a ``CONNECT host:port HTTP/x.y`` request.

    Only the first line of *request_text* is inspected; any headers after it
    are ignored. Bracketed IPv6 literals (``[::1]:443``) are accepted and
    returned without the brackets.

    Raises:
        ValueError: if the request is empty, is not a CONNECT request, lacks
            the HTTP version token, or has a missing/invalid host or port.
    """
    lines = request_text.splitlines()
    if not lines:
        raise ValueError("empty request")
    parts = lines[0].split()
    if len(parts) < 3 or parts[0].upper() != "CONNECT":
        raise ValueError("not a CONNECT request")
    # Split on the *last* colon so bracketed IPv6 authorities parse correctly.
    host, sep, port_text = parts[1].rpartition(":")
    if not sep:
        raise ValueError("missing port")
    if host.startswith("[") and host.endswith("]"):
        host = host[1:-1]
    elif ":" in host:
        # An unbracketed IPv6 literal is ambiguous; RFC 3986 requires brackets.
        raise ValueError("IPv6 literal must be bracketed")
    if not host:
        raise ValueError("missing host")
    if not port_text.isdigit():
        raise ValueError("port is not a number")
    port = int(port_text)
    if not 0 < port <= 65535:
        raise ValueError("port out of range")
    return host, port


async def _close_websocket(websocket: WebSocket) -> None:
    """Close *websocket*, ignoring errors if it is already closed or gone."""
    try:
        await websocket.close()
    except Exception:
        # Deliberate best-effort cleanup: Starlette raises (RuntimeError) when
        # close() is called after a close was already sent, and the client may
        # have disconnected already. Either way there is nothing left to do.
        pass


async def _close_writer(writer: asyncio.StreamWriter) -> None:
    """Close the TCP stream and wait for the transport to shut down."""
    writer.close()
    try:
        await writer.wait_closed()
    except OSError:
        # The remote end may already have reset the connection; the socket is
        # released regardless.
        pass


async def _relay(
    websocket: WebSocket,
    reader: asyncio.StreamReader,
    writer: asyncio.StreamWriter,
) -> None:
    """Pump data TCP <--> WebSocket until either direction finishes.

    As soon as one direction ends (EOF, disconnect or error), the other one is
    cancelled so the tunnel never lingers half-open.
    """

    async def tcp_to_ws():
        """Read from the TCP connection and forward it to the client as binary frames."""
        try:
            while True:
                data = await reader.read(_TCP_READ_SIZE)
                if not data:
                    break  # remote closed its side
                await websocket.send_bytes(data)
        except Exception as e:
            print("tcp_to_ws 异常:", e)

    async def ws_to_tcp():
        """Write data the client sends over the WebSocket into the TCP connection."""
        try:
            while True:
                message = await websocket.receive()
                # Per the ASGI spec a receive event may carry both keys with one
                # of them None, so test values rather than key presence.
                if message.get("bytes") is not None:
                    data = message["bytes"]
                elif message.get("text") is not None:
                    data = message["text"].encode("utf-8")
                else:
                    break  # e.g. a "websocket.disconnect" event
                writer.write(data)
                await writer.drain()
        except Exception as e:
            print("ws_to_tcp 异常:", e)

    tasks = [asyncio.create_task(tcp_to_ws()), asyncio.create_task(ws_to_tcp())]
    try:
        await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
    finally:
        # Also runs if this handler itself is cancelled: asyncio.wait does not
        # cancel the tasks it waits on, so do it explicitly.
        for task in tasks:
            task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)


@app.websocket("/")
async def tunnel(websocket: WebSocket):
    """WebSocket tunnel entry point.

    1. Accept the client's WebSocket connection.
    2. Wait for the client to send an HTTP CONNECT request as the first text
       message and parse the target host and port from it.
    3. Try to open a TCP connection to that target.
    4. Reply "HTTP/1.1 200 Connection Established" once the tunnel is up
       (400 for a malformed request, 502 if the target cannot be reached).
    5. Relay data in both directions (TCP <--> WebSocket) until either side
       finishes, then close both connections.

    NOTE(review): no authentication or destination allow-list is applied, so
    this endpoint behaves as an open proxy to any host reachable from the
    server (including internal addresses). Confirm this is intended.
    """
    await websocket.accept()
    writer = None
    try:
        request_text = await websocket.receive_text()
        try:
            dest_host, dest_port = _parse_connect_target(request_text)
        except ValueError:
            await websocket.send_text(_BAD_REQUEST)
            return  # the finally block closes the WebSocket exactly once

        try:
            reader, writer = await asyncio.open_connection(dest_host, dest_port)
        except Exception as e:
            err_msg = f"HTTP/1.1 502 Bad Gateway\r\n\r\n无法连接 {dest_host}:{dest_port},错误:{e}"
            await websocket.send_text(err_msg)
            return

        await websocket.send_text("HTTP/1.1 200 Connection Established\r\n\r\n")
        await _relay(websocket, reader, writer)
    except Exception as e:
        print("WebSocket 隧道处理异常:", e)
    finally:
        if writer is not None:
            await _close_writer(writer)
        await _close_websocket(websocket)
|
|
|
|
|
if __name__ == "__main__":
    # Pass the app object rather than the "app:app" import string: the string
    # form only works if this file is named app.py, and it makes uvicorn import
    # the script a second time as module "app". The object form is sufficient
    # because reload is disabled (reload/workers are what need an import string).
    uvicorn.run(app, host="0.0.0.0", port=7860, reload=False)
|
|