求助: tcp 连接转发

186 天前
 easymbol

目前我做测试想做一个特殊的转发,A 代码运行在公网监听两个端口,一个是用户请求;另一个是 B 代码的请求连接。A 接受到请求后将用户的请求转发到 B ,目前表象一直卡在了 B 回传回去的问题上(没有对 B 请求成功是否验证),请问各位大佬,这个如何操作

A 代码如下

package main

import (
	"fmt"
	"io"
	"net"
	"sync"
)

var (
	bConnMu sync.Mutex
	bConn   net.Conn
)

func handleUserRequest(userConn net.Conn) {
	defer func() {
		userConn.Close()
		fmt.Println("User connection closed")
	}()
	fmt.Println("Received user request")

	bConnMu.Lock()
	if bConn == nil {
		bConnMu.Unlock()
		fmt.Println("No connection to B machine")
		fmt.Fprintln(userConn, "No connection to B machine")
		return
	}
	bConnMu.Unlock()

	fmt.Println("Forwarding request to B machine")

	// 将用户请求转发给 B 机器
	err := forwardRequest(userConn, bConn)
	if err != nil {
		fmt.Println("Error forwarding request to B machine:", err)
		return
	}

	fmt.Println("User request completed")
}

func forwardRequest(userConn, bConn net.Conn) error {
	done := make(chan error, 1)

	// 将用户请求转发给 B 机器
	go func() {
		_, err := io.Copy(bConn, userConn)
		if err != nil {
			fmt.Println("Error forwarding request to B machine:", err)
		} else {
			fmt.Println("Finished forwarding request to B machine")
		}
		done <- err
	}()

	// 将 B 机器的响应转发给用户
	go func() {
		_, err := io.Copy(userConn, bConn)
		if err != nil {
			fmt.Println("Error forwarding response to user:", err)
		} else {
			fmt.Println("Finished forwarding response to user")
		}
		done <- err
	}()

	err := <-done
	if err != nil {
		return err
	}

	err = <-done
	return err
}

func handleBConnection(conn net.Conn) {
	fmt.Println("B machine connected")

	bConnMu.Lock()
	if bConn != nil {
		bConn.Close()
		fmt.Println("Closed previous connection to B machine")
	}
	bConn = conn
	bConnMu.Unlock()

	// 监听 B 机器的断开连接
	_, err := io.Copy(io.Discard, conn)
	if err != nil {
		fmt.Println("B machine disconnected with error:", err)
	} else {
		fmt.Println("B machine disconnected")
	}

	bConnMu.Lock()
	if bConn == conn {
		bConn = nil
		fmt.Println("Removed connection to B machine")
	}
	bConnMu.Unlock()
	conn.Close()
}

func main11() {
	// 启动监听用户代理请求的 goroutine
	fmt.Println("A machine listening for user requests on :12345")
	go func() {
		listener, err := net.Listen("tcp", ":12345")
		if err != nil {
			fmt.Println("Failed to listen for user requests:", err)
			return
		}
		defer listener.Close()

		for {
			conn, err := listener.Accept()
			if err != nil {
				fmt.Println("Failed to accept user request:", err)
				continue
			}

			go handleUserRequest(conn)
		}
	}()

	// 启动监听 B 机器连接的 goroutine
	fmt.Println("A machine listening for B machine connection on :12346")
	listener, err := net.Listen("tcp", ":12346")
	if err != nil {
		fmt.Println("Failed to listen for B machine connection:", err)
		return
	}
	defer listener.Close()

	for {
		conn, err := listener.Accept()
		if err != nil {
			fmt.Println("Failed to accept B machine connection:", err)
			continue
		}

		go handleBConnection(conn)
	}
}

B 代码如下

package main

import (
	"bufio"
	"fmt"
	"io"
	"log"
	"net"
	"net/http"
	"time"
)

func handleTunnel(conn net.Conn) {
	defer func() {
		conn.Close()
		fmt.Println("Tunnel connection closed")
	}()

	fmt.Println("Handling new tunnel connection")

	reader := bufio.NewReader(conn)
	for {
		conn.SetReadDeadline(time.Now().Add(30 * time.Second))

		request, err := http.ReadRequest(reader)
		if err != nil {
			if err == io.EOF {
				fmt.Println("A machine closed the connection")
				return
			}
			fmt.Println("Error reading request:", err)
			return
		}

		fmt.Printf("Received request from A machine: %s %s\n", request.Method, request.URL)

		if request.Method == http.MethodConnect {
			fmt.Printf("Processing CONNECT request: %s\n", request.URL.Host)

			targetConn, err := net.DialTimeout("tcp", request.URL.Host, 10*time.Second)
			if err != nil {
				log.Printf("Error connecting to target: %v\n", err)
				// 修改这里来手动发送 HTTP 响应
				conn.Write([]byte("HTTP/1.1 503 Service Unavailable\r\n\r\n"))
				return
			}

			conn.Write([]byte("HTTP/1.1 200 Connection Established\r\n\r\n"))

			// 设置通道同步 goroutines
			done := make(chan struct{})

			// 开启 goroutine 处理从 client 到目标主机的流量
			go func() {
				io.Copy(targetConn, conn) // 假设 conn 是从 net.Listen 获取的连接
				close(done)
			}()

			// 开启 goroutine 处理从目标主机到 client 的流量
			go func() {
				io.Copy(conn, targetConn)
				close(done)
			}()

			// 等待至少一个方向的流完成
			<-done

			targetConn.Close()
			conn.Close()
		} else {
			fmt.Printf("Processing normal request: %s\n", request.URL)

			// 处理普通请求
			targetConn, err := net.DialTimeout("tcp", request.Host, 10*time.Second)
			if err != nil {
				fmt.Println("Error connecting to target host:", err)
				conn.Write([]byte("HTTP/1.1 503 Service Unavailable\r\n\r\n"))
				continue
			}
			defer func() {
				targetConn.Close()
				fmt.Printf("Target connection to %s closed\n", request.Host)
			}()

			fmt.Println("Connected to target host:", request.Host)

			// 将请求转发给目标主机
			err = request.Write(targetConn)
			if err != nil {
				fmt.Println("Error forwarding request to target host:", err)
				conn.Write([]byte("HTTP/1.1 503 Service Unavailable\r\n\r\n"))
				continue
			}
			fmt.Println("Forwarded request to target host")

			// 将目标主机的响应转发给 A 机器
			response, err := http.ReadResponse(bufio.NewReader(targetConn), request)
			if err != nil {
				fmt.Println("Error reading response from target host:", err)
				conn.Write([]byte("HTTP/1.1 503 Service Unavailable\r\n\r\n"))
				continue
			}
			fmt.Println("Read response from target host")

			err = response.Write(conn)
			if err != nil {
				fmt.Println("Error forwarding response to A machine:", err)
				continue
			}
			fmt.Println("Forwarded response to A machine")
		}
	}
}

func connectToA() {
	for {
		fmt.Println("Attempting to connect to A machine")
		conn, err := net.Dial("tcp", "A 机器地址:12346")
		if err != nil {
			fmt.Println("Failed to connect to A machine:", err)
			time.Sleep(5 * time.Second)
			continue
		}
		fmt.Println("Connected to A machine")

		// 发送心跳包
		go func() {
			for {
				_, err := conn.Write([]byte("ping"))
				if err != nil {
					fmt.Println("Failed to send heartbeat:", err)
					conn.Close()
					return
				}
				time.Sleep(5 * time.Second)
			}
		}()

		handleTunnel(conn)
	}
}

func main() {
	go connectToA()

	select {} // 阻塞主线程
}
1621 次点击
所在节点    Go 编程语言
9 条回复
mango88
186 天前
// 监听 B 机器的断开连接
_, err := io.Copy(io.Discard, conn)

这里占用了 socket 的输入流
lifei6671
186 天前
你这代码写的槽点太多了。首先你要理解,
1 、读写 net.Conn 必须独占不能共享。
2 、go 启动的协程必须有退出的时机,否则就会协程泄漏。
3 、chan 是不能多次 close 的,否则会 panic 。
4 、等待多个协程退出后继续执行,建议使用 sync.WaitGroup
xxxccc
186 天前
感觉你想要写一个 tcp proxy ,简单看了下代码,有很多奇怪的地方。建议你先让 chatgpt 帮你写一个 golang 的 tcp proxy
lt0136
186 天前
梳理一下你的需求:
B 的作用是 HTTP 代理,A 的作用是公网转发请求到 B

这个不需要自己写啊,B 启动一个的 HTTP 代理( squid ),再用内网穿透工具比如 frp 之类的把 B 的代理端口映射到 A 就好了
easymbol
186 天前
@lifei6671 😂 槽点的确多,这个并不擅长
gochat
186 天前
easymbol
186 天前
@gochat 好的,感谢
easymbol
186 天前
@lifei6671 好的,我尝试看看先
Ehco1996
186 天前

这是一个专为移动设备优化的页面(即为了让你能够在 Google 搜索结果里秒开这个页面),如果你希望参与 V2EX 社区的讨论,你可以继续到 V2EX 上打开本讨论主题的完整版本。

https://www.v2ex.com/t/1049602

V2EX 是创意工作者们的社区,是一个分享自己正在做的有趣事物、交流想法,可以遇见新朋友甚至新机会的地方。

V2EX is a community of developers, designers and creative people.

© 2021 V2EX