DEV Community

William Gough

Originally published at dev.to

How to test TCP/UDP connections in Go - Part 1

This article was originally posted on my blog - How to test TCP/UDP connections in Go - Part 1

Introduction

For a recent work task, I had to expose a key-value store (normally accessed through a REST API) via stream- and packet-based protocols such as TCP and UDP, in a REST-esque manner. This presented an exciting challenge for me, as I am relatively new to writing Go professionally, and I wanted to ensure a number of things:

  1. The overall integrity and reliability of the software I was writing
  2. Cleaner code
  3. Rapid development cycle
  4. Regression detection

I was given no specific requirements on how to interact with the exposed interface, but I knew that writing tests would be the most effective and efficient way to tackle a new challenge.

This tutorial will demonstrate my approach to testing TCP/UDP connections in Golang in a simple yet effective way. The step-by-step instructions will provide guidance to verify your network connections produce the desired output, especially if you’re new to Go or testing (like me).

I’m not going to re-implement the original project for the sake of this post; however, I am going to create something similar to demonstrate the problem and what I found to be a sensible and quick approach to solving it.

Setting up

Firstly, let's create a new working directory and create 2 new Go files. I normally work inside /go/src/github.com/me, but it's totally up to you:

cd $GOPATH/src/github.com/<you>
mkdir net-testing && cd $_
touch net.go net_test.go
# Go 1.16+ enables modules by default, so go test also needs a go.mod:
go mod init github.com/<you>/net-testing

Fire up your favorite editor (I'm using VS Code; it has great support for Go, with auto-formatting and auto-test on save) and let's start building out our application. First up, in net_test.go, we need to bootstrap the application so that the server starts when the tests run.

N.B.: Running a network server, whether it's TCP, UDP or HTTP, will block the Goroutine that runs it (typically the main Goroutine) until the application is shut down or the server errors.

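package net

import (
    "log"
)

var srv Server

func init() {
    // Start the new server. Assign to the package-level srv with "="
    // rather than ":=", which would declare a new local srv and leave
    // the global one nil.
    var err error
    srv, err = NewServer("tcp", ":1123")
    if err != nil {
        log.Println("error starting TCP server:", err)
        return
    }

    // Run the server in a Goroutine so that its blocking accept
    // loop does not stop the tests from executing.
    go func() {
        if err := srv.Run(); err != nil {
            log.Println("TCP server stopped:", err)
        }
    }()
}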

The above shows how we declare a global variable of type Server, which we haven't created yet, then use the init function to bootstrap our server in a new Goroutine. Next, we're going to write a test to verify that the server has started.

// Be sure to update your imports
import (
    "log"
    "net"
    "testing"
)


// Below init function
func TestNETServer_Run(t *testing.T) {
    // Simply check that the server is up and can
    // accept connections.
    conn, err := net.Dial("tcp", ":1123")
    if err != nil {
        // Fatal stops the test here; with Error, the deferred
        // conn.Close() below would panic on a nil conn.
        t.Fatal("could not connect to server:", err)
    }
    defer conn.Close()
}

Running the tests will inevitably result in a failed build right now, but don't worry about that. Now that we have our test ready to check our server runs, let's start the server implementation. Inside net.go, add the following code:

package net

import (
    "errors"
    "fmt"
    "net"
    "strings"
)

// Server defines the minimum contract our
// TCP and UDP server implementations must satisfy.
type Server interface {
    Run() error
    Close() error
}

// NewServer creates a new Server using given protocol
// and addr.
func NewServer(protocol, addr string) (Server, error) {
    switch strings.ToLower(protocol) {
    case "tcp":
        return &TCPServer{
            addr: addr,
        }, nil
    case "udp":
        // The UDP implementation arrives in Part 2; until then,
        // "udp" reaches the error return below.
    }
    return nil, errors.New("invalid protocol given")
}

// TCPServer holds the structure of our TCP
// implementation.
type TCPServer struct {
    addr   string
    server net.Listener
}

// Run starts the TCP Server.
func (t *TCPServer) Run() (err error) {
    t.server, err = net.Listen("tcp", t.addr)
    if err != nil {
        return err
    }
    for {
        conn, err := t.server.Accept()
        if err != nil {
            // Return the error explicitly: assigning to the err declared
            // inside this loop would not change Run's named result.
            return fmt.Errorf("could not accept connection: %w", err)
        }
        // For now, accept each connection and close it immediately.
        conn.Close()
    }
}

// Close shuts down the TCP Server.
func (t *TCPServer) Close() (err error) {
    if t.server == nil {
        return nil // Run has not started listening yet.
    }
    return t.server.Close()
}

Whether you're new to Go or a comfortable Go programmer, bear with me whilst I break this down. First, we declare a new interface type named Server, so that every server implementation exposes the same API; in the next part of this series we'll build a UDPServer type alongside TCPServer. Be careful, though: adding interfaces to your application too early normally results in over-engineering, so always consider YAGNI... "you ain't gonna need it". Next, we write a factory function, NewServer, that returns a Server implementation based on the chosen protocol. Finally, we create our TCPServer type, which implicitly satisfies the Server interface because it defines Run and Close methods (Go has no implements keyword). For now, all we've achieved is passing the test we defined. As the saying goes... Red, Green, Refactor.

Next, let's write a test that covers handling output from the server. This time we're going to use table-driven testing: we define a series of test cases and run each one inside a for loop as a subtest with t.Run, meaning we can efficiently test different outcomes, like so:

// Update your imports again: this test also needs "bytes" and "time"
import (
    "bytes"
    "log"
    "net"
    "testing"
    "time"
)

func TestNETServer_Request(t *testing.T) {
    tt := []struct {
        test    string
        payload []byte
        want    []byte
    }{
        {
            "Sending a simple request returns result",
            []byte("hello world\n"),
            // The server echoes the request line, including its newline.
            []byte("Request received: hello world\n"),
        },
        {
            "Sending another simple request works",
            []byte("goodbye world\n"),
            []byte("Request received: goodbye world\n"),
        },
    }

    for _, tc := range tt {
        t.Run(tc.test, func(t *testing.T) {
            conn, err := net.Dial("tcp", ":1123")
            if err != nil {
                t.Fatal("could not connect to TCP server:", err)
            }
            defer conn.Close()

            if _, err := conn.Write(tc.payload); err != nil {
                t.Fatal("could not write payload to TCP server:", err)
            }

            // Fail instead of hanging forever if the server never replies.
            conn.SetReadDeadline(time.Now().Add(time.Second))

            out := make([]byte, 1024)
            n, err := conn.Read(out)
            if err != nil {
                t.Fatal("could not read from connection:", err)
            }
            // Compare only the n bytes actually read, not the whole buffer.
            if bytes.Compare(out[:n], tc.want) != 0 {
                t.Errorf("response did not match expected output: got %q, want %q", out[:n], tc.want)
            }
        })
    }
}

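The tests above send a payload (a byte slice) to our server and then read from the connection to see what the response was. The bytes package offers several ways to check that response: bytes.Index returns the index of a substring (or -1 if it is absent), and bytes.Contains simply reports whether the substring is present. To test for a direct match between the expected and actual response, we can use the approach outlined above, bytes.Compare(a, b), where a return value of 0 means a == b; bytes.Equal(a, b) expresses the same check more directly. Whichever you choose, compare only the n bytes that conn.Read reports having read, because the rest of the 1024-byte buffer is zero padding that will never match.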
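To make those options concrete, here is a small sketch (the helper name matchResponse is hypothetical) that runs all three checks on a response, including the zero-padded buffer case that arises when a fixed-size read buffer is compared without slicing it to the number of bytes read:

```go
package main

import "bytes"

// matchResponse is a hypothetical helper that applies the three checks
// discussed above to a server response.
func matchResponse(got, want, sub []byte) (exact, contains bool, index int) {
	exact = bytes.Compare(got, want) == 0 // same result as bytes.Equal(got, want)
	contains = bytes.Contains(got, sub)   // true if sub appears anywhere in got
	index = bytes.Index(got, sub)         // first position of sub in got, or -1
	return exact, contains, index
}
```

With want set to "Request received: hello world\n" and copied into a 1024-byte buffer, passing the whole buffer reports exact == false because of the trailing zero bytes, while passing buf[:n] with sub "hello" reports an exact match, contains == true and index == 18.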

As the application currently stands, the tests will fail, so now we need to add the code that will actually handle input in an idiomatic way:

// Add "bufio", "fmt" and "io" to the imports in net.go if they are not there yet.
func (t *TCPServer) handleConnections() error {
    for {
        conn, err := t.server.Accept()
        if err != nil {
            return fmt.Errorf("could not accept connection: %w", err)
        }

        // Handle each connection in its own Goroutine so that one
        // slow client cannot block the others.
        go t.handleConnection(conn)
    }
}

func (t *TCPServer) handleConnection(conn net.Conn) {
    defer conn.Close()

    rw := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
    for {
        // Read one newline-terminated request; req keeps the '\n'.
        req, err := rw.ReadString('\n')
        if err != nil {
            // io.EOF just means the client hung up; report anything else.
            if !errors.Is(err, io.EOF) {
                rw.WriteString("failed to read input\n")
                rw.Flush()
            }
            return
        }

        fmt.Fprintf(rw, "Request received: %s", req)
        rw.Flush()
    }
}

So, what are we doing here? Well, we start by defining a new method on TCPServer, handleConnections, that continually listens for and accepts connection requests to the server. If accepting a connection fails, it stops looping and returns the error. If all goes to plan, the connection is passed to the second method, handleConnection, where we continually read newline-terminated messages and respond to the client until it disconnects. Using the go keyword allows us to handle each connection inside its own Goroutine. Most importantly, let's not forget to update our Run method to use the new handleConnections. Inside Run, replace the whole for loop (not just the conn.Close() line) with return t.handleConnections(); since handleConnections also returns an error, we can use the method call as a return value. If only conn.Close() were swapped out, Run would still accept the very first connection itself and then hand off to handleConnections without ever serving or closing it.

If you run the tests now with go test -v -cover, you should see output similar to the following:

=== RUN   TestNETServer_Run
--- PASS: TestNETServer_Run (0.00s)
=== RUN   TestNETServer_Request
=== RUN   TestNETServer_Request/Sending_a_simple_request_returns_result
=== RUN   TestNETServer_Request/Sending_another_simple_request_works
--- PASS: TestNETServer_Request (0.00s)
    --- PASS: TestNETServer_Request/Sending_a_simple_request_returns_result (0.00s)
    --- PASS: TestNETServer_Request/Sending_another_simple_request_works (0.00s)
PASS
coverage: 68.6% of statements

Awesome, we have a working TCP server! However simple the implementation, we can be confident that it returns the output we expect. Hopefully, if you're new to network programming in Go or to testing in general, you can draw some inspiration from this tutorial. As I mentioned earlier, in the next installment of the series I'm going to look at adding a UDPServer and updating the tests to efficiently cover both network protocols. Thanks for reading; if you liked the content, please consider sharing!


The source code for this example can be found on my GitHub here: github.com/williamhgough/devtheweb-source

Top comments (1)

Alex Schmitz (davemcphee)

This is great, thank you for this elegant example! But I think there's a bug, unless I'm wrong, in which case there's one line I don't understand.

In the Run() function, you accept a connection, then launch the handleConnections() func, which accepts connections as well. As far as I can tell, that first connection to hit this server is discarded, but subsequent ones, which are handled in handleConnections(), run fine.

In my implementation, I removed the inf loop in Run(), and just return t.handleConnections() - is this OK or am I missing something?