Recap From Previous Article
In the previous article on building a minimal working WebSocket server, we explored the opening handshake, creating a client request, walking through the WebSocket frame format, and echoing the client's request back.
Article Outline
This time, we're going to refine our previous code to be more RFC 6455 compliant, especially regarding the payload-length field of the frame. We'll stick to RFC 6455, focus on payload length, discuss network byte order, and make a few refinements beyond payload length to move toward more serious engineering.
Why Correctness Matters? It Already Works!
Building minimal working code is fine for learning, but it is not enough, especially once the server is used by real-world clients such as web browsers or cloud services that rely on WebSocket to communicate with their clients. Those real-world technologies expect strict RFC 6455 compliance. And remember: just because it works does not mean it is correct. Small deviations in request handling, such as incorrect payload length encoding or missing mask enforcement, can break compatibility or lead to undefined behavior. Correctness ensures that our server interoperates reliably with any standards-compliant client.
What Was Missing Then?
Section 5.2 of RFC 6455 explains the payload length field. It says:
The length of the "Payload data", in bytes: if 0-125, that is the
payload length. If 126, the following 2 bytes interpreted as a
16-bit unsigned integer are the payload length. If 127, the
following 8 bytes interpreted as a 64-bit unsigned integer (the
most significant bit MUST be 0) are the payload length. Multibyte
length quantities are expressed in network byte order. Note that
in all cases, the minimal number of bytes MUST be used to encode
the length, for example, the length of a 124-byte-long string
can't be encoded as the sequence 126, 0, 124. The payload length
is the length of the "Extension data" + the length of the
"Application data". The length of the "Extension data" may be
zero, in which case the payload length is the length of the
"Application data".
In plain English:
The payload-length fragment (the 7-bit field in the second header byte) does not always hold the actual payload length. The values 126 and 127 are signals that tell the receiver where it MUST read the actual payload length from.
For example:
- If a client sends a payload exactly 64 bytes long, the payload-length fragment is 64, and that is the actual length.
- If a client sends a payload exactly 126 bytes long, the fragment is 126, so the server MUST read the next 2 bytes (0x00 0x7E) to get the actual payload length instead of using the fragment itself.
- If a client sends a payload exactly 256 bytes long, the fragment is also 126 and the next 2 bytes (0x01 0x00) hold the length, because 256 fits in 16 bits. Only payloads longer than 65,535 bytes, such as 70,000 bytes, use fragment 127, and the server MUST then read the next 8 bytes.
In short:
- If the payload-length fragment is 0-125: the fragment itself is the actual payload length
- If the fragment is 126: the next 2 bytes encode the actual payload length as a 16-bit unsigned integer; the server MUST NOT treat the fragment as the payload length and MUST instead read the extended payload length
- If the fragment is 127: the next 8 bytes encode the actual payload length as a 64-bit unsigned integer whose most significant bit MUST be 0; again, the server MUST read the extended payload length instead of the fragment
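The three rules above can be sketched as a standalone Go function. This is a minimal sketch, not the article's server code: the helper name decodePayloadLength is hypothetical, and it assumes the frame header is already available as a byte slice instead of being streamed from a bufio.Reader. It returns the payload length and how many header bytes were used (2, 4, or 10).

```go
package main

import (
	"encoding/binary"
	"errors"
	"fmt"
)

// decodePayloadLength (hypothetical helper) reads the 7-bit payload-length
// fragment from header[1] and, when it is 126 or 127, the extended length
// that follows. It returns the payload length and the header bytes used.
func decodePayloadLength(header []byte) (uint64, int, error) {
	if len(header) < 2 {
		return 0, 0, errors.New("header too short")
	}
	indicator := header[1] & 0x7f // drop the mask bit
	switch {
	case indicator <= 125:
		// the fragment itself is the payload length
		return uint64(indicator), 2, nil
	case indicator == 126:
		// next 2 bytes: 16-bit unsigned length in network byte order
		if len(header) < 4 {
			return 0, 0, errors.New("missing 16-bit extended length")
		}
		return uint64(binary.BigEndian.Uint16(header[2:4])), 4, nil
	default: // 127
		// next 8 bytes: 64-bit unsigned length in network byte order
		if len(header) < 10 {
			return 0, 0, errors.New("missing 64-bit extended length")
		}
		n := binary.BigEndian.Uint64(header[2:10])
		if n>>63 != 0 {
			return 0, 0, errors.New("most significant bit MUST be 0")
		}
		return n, 10, nil
	}
}

func main() {
	// masked text frame announcing a 256-byte payload: 0x81, 0x80|126, 0x01, 0x00
	n, used, err := decodePayloadLength([]byte{0x81, 0xfe, 0x01, 0x00})
	fmt.Println(n, used, err) // 256 4 <nil>
}
```

The number of header bytes used also tells a reader where the masking key starts, which is exactly what the old echo server gets wrong.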
Let's look back at what we built previously, a server that echoes the client's request:
```go
func (ws *ws) readRequest(requestSize int) (frame, error) {
	data := make([]byte, requestSize)
	_, err := ws.reader.Read(data)
	if err != nil {
		return frame{}, err
	}
	opcode := data[0] & 0xf
	payloadLength := data[1] & 0x7f
	return frame{
		opcode:        opcode,
		payloadLength: int(payloadLength),
	}, nil
}

// writeTextFrameResponse echoes server response as same as
// client request. This works normally when client
// request length is < 126
func (ws *ws) writeTextFrameResponse(frame frame) error {
	// read masking key where masking key located
	// 3rd - 6th byte of request, first 2 already read
	// to retrieve opcode and payload length.
	maskingKey := make([]byte, 4)
	_, err := ws.reader.Read(maskingKey)
	if err != nil {
		return err
	}
	requestData := make([]byte, frame.payloadLength)
	_, err = ws.reader.Read(requestData)
	if err != nil {
		return err
	}
	// unmasking request payload, refer to https://datatracker.ietf.org/doc/html/rfc6455#autoid-24
	for i := 0; i < len(requestData); i++ {
		requestData[i] = requestData[i] ^ maskingKey[i%4]
	}
	responseByte := []byte(requestData)
	responseFrame := make([]byte, 2)
	responseFrame[0] = 0x80 | frame.opcode
	responseFrame[1] = byte(len(responseByte))
	responseFrame = append(responseFrame, responseByte...)
	_, err = ws.writer.Write(responseFrame)
	if err != nil {
		return err
	}
	err = ws.writer.Flush()
	if err != nil {
		return err
	}
	return nil
}
```
The readRequest function reads the payload-length fragment but never reads the actual (extended) payload length. The problem surfaces when the frame reaches writeTextFrameResponse, which is only valid when the fragment is 125 or less. If a client sends a frame whose fragment is 126, the server misbehaves once the request reaches this function: it interprets the next 4 bytes as the masking key, while in fact the next 2 bytes are the extended (actual) payload length, and the masking key is the 4 bytes that follow them.
```go
maskingKey := make([]byte, 4)
_, err := ws.reader.Read(maskingKey)
if err != nil {
	return err
}
```
This code reads the next 4 bytes as the masking key. That is not correct: the RFC states that when the payload-length fragment is 126, the next 2 bytes are the actual payload length, and the masking key is the 4 bytes after them. If we still let the server echo such a request, it unmasks the payload with the wrong key and sends back a malformed frame, so the client closes the connection immediately.
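To make the offset problem concrete, here is a small sketch that computes where the 4-byte masking key starts in a client frame for each payload-length fragment. The helper name maskingKeyOffset is hypothetical and not part of the article's server; the old writeTextFrameResponse effectively always assumes the answer is 2.

```go
package main

import "fmt"

// maskingKeyOffset (hypothetical helper) returns the index of the first
// masking-key byte in a client frame, given the 7-bit payload-length fragment.
func maskingKeyOffset(indicator byte) int {
	switch {
	case indicator <= 125:
		return 2 // 2 header bytes, no extended length
	case indicator == 126:
		return 4 // 2 header bytes + 2-byte extended length
	default:
		return 10 // 2 header bytes + 8-byte extended length
	}
}

func main() {
	for _, ind := range []byte{64, 126, 127} {
		start := maskingKeyOffset(ind)
		fmt.Printf("fragment %d: masking key at byte indexes %d-%d\n", ind, start, start+3)
	}
}
```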
Network Byte Order
Section 5.2 explicitly states
Multibyte length quantities are expressed in network byte order.
What is network byte order, and why must we follow this rule?
Network byte order is the convention for arranging multibyte values when devices transfer data across a network, and it is always big-endian.
Even if your computer stores integers in little-endian order (as x86 machines do), multibyte fields in network protocols such as TCP/IP headers and WebSocket frames must be rearranged into big-endian order before they are sent. Here is the difference between big-endian and little-endian ordering:
- Big-endian: the most significant byte (the leftmost one when the number is written out) goes first over the network
- Little-endian: the least significant byte (the rightmost one) goes first
For example, take the 32-bit value 0x12345678 and call it x:
- x in big-endian format is the byte sequence 0x12 0x34 0x56 0x78
- x in little-endian format is the byte sequence 0x78 0x56 0x34 0x12
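The same comparison can be reproduced with Go's standard encoding/binary package. This is a small illustrative sketch; the helper name byteOrders is hypothetical.

```go
package main

import (
	"encoding/binary"
	"fmt"
)

// byteOrders (hypothetical helper) returns x encoded in both byte orders.
func byteOrders(x uint32) (big, little []byte) {
	big = make([]byte, 4)
	binary.BigEndian.PutUint32(big, x) // most significant byte first
	little = make([]byte, 4)
	binary.LittleEndian.PutUint32(little, x) // least significant byte first
	return big, little
}

func main() {
	big, little := byteOrders(0x12345678)
	fmt.Printf("big-endian:    % x\n", big)    // 12 34 56 78
	fmt.Printf("little-endian: % x\n", little) // 78 56 34 12
}
```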
This detail plays a significant role in the WebSocket extended payload-length fragment, which always uses big-endian encoding:
- If the payload-length fragment is 126: the next 2 bytes are the actual payload length in big-endian order
- If the payload-length fragment is 127: the next 8 bytes are the actual payload length in big-endian order
It might look trivial at first glance, but byte order is critical for a WebSocket server. Let's use our echo server as an example.
Consider a client that sends a request whose payload is exactly 256 bytes:
- Payload-length fragment: 0x7E (126), because a 256-byte payload does not fit in 7 bits but does fit in the 16-bit extended payload length
- Extended payload-length fragment: 0x01 0x00
If we mistakenly encode the extended payload length in little-endian order (0x00 0x01), the receiver interprets it as a length of 1 byte, which corrupts the entire frame.
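The 256-byte example can be checked with encoding/binary. In this sketch the helper name extendedLength16 is hypothetical; it produces the 2 bytes that follow a fragment of 126, and the little-endian mistake is simulated by encoding the same value the wrong way round and decoding it as a compliant receiver would.

```go
package main

import (
	"encoding/binary"
	"fmt"
)

// extendedLength16 (hypothetical helper) encodes n as the 2-byte extended
// payload length that follows a payload-length fragment of 126.
func extendedLength16(n uint16) []byte {
	b := make([]byte, 2)
	binary.BigEndian.PutUint16(b, n) // network byte order
	return b
}

func main() {
	correct := extendedLength16(256)
	fmt.Printf("% x -> %d\n", correct, binary.BigEndian.Uint16(correct)) // 01 00 -> 256

	// the same length written little-endian by mistake...
	wrong := make([]byte, 2)
	binary.LittleEndian.PutUint16(wrong, 256)
	// ...is decoded by a compliant (big-endian) receiver as 1
	fmt.Printf("% x -> %d\n", wrong, binary.BigEndian.Uint16(wrong)) // 00 01 -> 1
}
```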
Handle Extended Payload Length Properly
Now that we understand the actual payload length, network byte order, and big-endian ordering, it's time to handle the extended payload length properly. Let's begin with the current readRequest function, which accepts the client's request and returns a request frame.
```go
// Previous simplified version (for demonstration only)
func (ws *ws) readRequest(requestSize int) (frame, error) {
	data := make([]byte, requestSize)
	_, err := ws.reader.Read(data)
	if err != nil {
		return frame{}, err
	}
	opcode := data[0] & 0xf
	payloadLength := data[1] & 0x7f
	return frame{
		opcode:        opcode,
		payloadLength: int(payloadLength),
	}, nil
}
```
Earlier, our readRequest function read only the first 2 bytes and ignored any additional bytes. Let's refactor it to handle the payload-length fragment as defined by RFC 6455, so that it:
- reads the opcode
- reads the mask bit
- reads the payload-length fragment (the indicator that tells us whether the actual payload length is inline, in the next 2 bytes, or in the next 8 bytes)
- reads the additional bytes as the actual payload length (next 2 or 8 bytes)
- returns the metadata needed to read the payload safely
```go
func (ws *ws) readRequest(requestSize int) (frame, error) {
	data := make([]byte, requestSize)
	_, err := io.ReadFull(ws.reader, data)
	if err != nil {
		return frame{}, err
	}
	var actualPayloadLength int64
	opcode := data[0] & 0xf
	payloadIndicator := data[1] & 0x7f
	isMasked := data[1]&0x80 != 0
	frame := frame{}
	// read payload length frame
	switch {
	case payloadIndicator < 126:
		actualPayloadLength = int64(payloadIndicator)
	case payloadIndicator == 126:
		// actual payload length is the next 2 bytes
		actualPayloadByte := make([]byte, 2)
		_, err := io.ReadFull(ws.reader, actualPayloadByte)
		if err != nil {
			return frame, err
		}
		actualPayloadLength = int64(binary.BigEndian.Uint16(actualPayloadByte))
	case payloadIndicator == 127:
		// actual payload length is the next 8 bytes
		actualPayloadByte := make([]byte, 8)
		_, err := io.ReadFull(ws.reader, actualPayloadByte)
		if err != nil {
			return frame, err
		}
		// a set most significant bit (forbidden by Section 5.2) becomes a
		// negative int64 here, which writeTextFrameResponse rejects
		actualPayloadLength = int64(binary.BigEndian.Uint64(actualPayloadByte))
	}
	frame.opcode = opcode
	frame.payloadIndicator = int64(payloadIndicator)
	frame.actualPayloadLength = actualPayloadLength
	frame.isMasked = isMasked
	return frame, nil
}
```
Note: requestSize is kept for simplicity since our earlier code used it. A proper WebSocket frame reader always reads exactly 2 bytes first.
Notice the changes we've made in the refactored version compared to the previous one:
- We defined payloadIndicator, a flag that tells us where the actual payload length is.
- A switch statement compares payloadIndicator against the payload-length fragment values defined by RFC 6455.
- We handle each payload-length fragment according to RFC 6455's payload-length definition. We use io.ReadFull to read the next n bytes when the fragment is 126 or 127; unlike a single Read call, which may return fewer bytes than requested, io.ReadFull keeps reading until the buffer is full or an error occurs.
- We decode the actual payload length using binary.BigEndian, because RFC 6455 requires all multibyte values to be transmitted in network byte order.
- We store the decoded actual payload length (converted from big-endian bytes into an int64) in actualPayloadLength.
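One rule from the quoted Section 5.2 that readRequest does not enforce yet is that "the minimal number of bytes MUST be used to encode the length". The sketch below shows one possible check to run after the length has been decoded; the helper name isMinimalEncoding is hypothetical, and it takes a uint64 length, so the int64 actualPayloadLength would need converting after the negative-value check.

```go
package main

import "fmt"

// isMinimalEncoding (hypothetical helper) reports whether a decoded payload
// length used the shortest form allowed by RFC 6455 Section 5.2:
// 0-125 inline, 126-65535 after fragment 126, larger values after fragment 127.
func isMinimalEncoding(indicator byte, actualPayloadLength uint64) bool {
	switch {
	case indicator <= 125:
		return true // an inline length is always minimal
	case indicator == 126:
		return actualPayloadLength >= 126 // rejects e.g. 126, 0, 124
	default: // 127
		return actualPayloadLength > 0xFFFF
	}
}

func main() {
	fmt.Println(isMinimalEncoding(126, 124)) // false: the RFC's own counterexample
	fmt.Println(isMinimalEncoding(126, 256)) // true
}
```

A server that finds a non-minimal encoding could treat it as a protocol error; RFC 6455 Section 7.4.1 defines close code 1002 for that situation.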
Echoing Request Back Correctly
Let's go back to the previous writeTextFrameResponse, which echoes the client's request:
```go
// writeTextFrameResponse echoes server response as same as
// client request. This works normally when client
// request length is < 126
func (ws *ws) writeTextFrameResponse(frame frame) error {
	// read masking key where masking key located
	// 3rd - 6th byte of request, first 2 already read
	// to retrieve opcode and payload length.
	maskingKey := make([]byte, 4)
	_, err := ws.reader.Read(maskingKey)
	if err != nil {
		return err
	}
	requestData := make([]byte, frame.payloadLength)
	_, err = ws.reader.Read(requestData)
	if err != nil {
		return err
	}
	// unmasking request payload, refer to https://datatracker.ietf.org/doc/html/rfc6455#autoid-24
	for i := 0; i < len(requestData); i++ {
		requestData[i] = requestData[i] ^ maskingKey[i%4]
	}
	responseByte := []byte(requestData)
	responseFrame := make([]byte, 2)
	responseFrame[0] = 0x80 | frame.opcode
	responseFrame[1] = byte(len(responseByte))
	responseFrame = append(responseFrame, responseByte...)
	_, err = ws.writer.Write(responseFrame)
	if err != nil {
		return err
	}
	err = ws.writer.Flush()
	if err != nil {
		return err
	}
	return nil
}
```
Our old writeTextFrameResponse explicitly stated that it only handles requests whose length is less than 126. Now let's refactor it so that it:
- Closes the connection immediately when the frame is not masked
- Closes the connection immediately when the request payload length exceeds the server's capacity
- Calculates the response size before building the response frame, instead of reflecting the request's length field
- Builds the response frame based on the response size
Here is the revised writeTextFrameResponse that meets the expectations above:
```go
const (
	MAX_PAYLOAD_LENGTH = 262144 // 256 KiB
)

func (ws *ws) writeTextFrameResponse(frame frame) error {
	if !frame.isMasked {
		return ws.close(frame)
	}
	// a negative length means the 64-bit length had its most significant bit set
	if frame.actualPayloadLength < 0 || frame.actualPayloadLength > MAX_PAYLOAD_LENGTH {
		// RFC 6455 Section 7.4.1 defines 1009 (message too big) for this case;
		// close() sends 1000 for simplicity
		if err := ws.close(frame); err != nil {
			return err
		}
		// stop here: the unread payload would otherwise be parsed as the next header
		return errors.New("payload length exceeds MAX_PAYLOAD_LENGTH")
	}
	maskingKey := make([]byte, 4)
	_, err := io.ReadFull(ws.reader, maskingKey)
	if err != nil {
		return err
	}
	requestData := make([]byte, frame.actualPayloadLength)
	_, err = io.ReadFull(ws.reader, requestData)
	if err != nil {
		return err
	}
	// unmasking request payload, refer to https://datatracker.ietf.org/doc/html/rfc6455#autoid-24
	for i := 0; i < len(requestData); i++ {
		requestData[i] = requestData[i] ^ maskingKey[i%4]
	}
	responseByte := []byte(requestData)
	if len(responseByte) <= 125 {
		responseFrame := make([]byte, 2)
		responseFrame[0] = 0x80 | frame.opcode
		responseFrame[1] = byte(len(responseByte))
		responseFrame = append(responseFrame, responseByte...)
		_, err = ws.writer.Write(responseFrame)
		if err != nil {
			return err
		}
	}
	if len(responseByte) >= 126 && len(responseByte) <= 65535 {
		responseFrame := make([]byte, 2)
		responseFrame[0] = 0x80 | frame.opcode
		responseFrame[1] = 0x7e
		actualPayloadFrameByte := make([]byte, 2)
		binary.BigEndian.PutUint16(actualPayloadFrameByte, uint16(len(responseByte)))
		responseFrame = append(responseFrame, actualPayloadFrameByte...)
		responseFrame = append(responseFrame, responseByte...)
		_, err = ws.writer.Write(responseFrame)
		if err != nil {
			return err
		}
	}
	if len(responseByte) >= 65536 {
		responseFrame := make([]byte, 2)
		responseFrame[0] = 0x80 | frame.opcode
		responseFrame[1] = 0x7f
		actualPayloadFrameByte := make([]byte, 8)
		binary.BigEndian.PutUint64(actualPayloadFrameByte, uint64(len(responseByte)))
		responseFrame = append(responseFrame, actualPayloadFrameByte...)
		responseFrame = append(responseFrame, responseByte...)
		_, err = ws.writer.Write(responseFrame)
		if err != nil {
			return err
		}
	}
	err = ws.writer.Flush()
	if err != nil {
		return err
	}
	return nil
}

func (ws *ws) close(frame frame) error {
	if !frame.isMasked {
		// RFC 6455 Section 5.1: an unmasked client frame means the server
		// MUST close the connection
		ws.conn.Close()
		return errors.New("frame MUST be masked")
	}
	responseFrame := []byte{
		0x88,       // FIN + Close opcode
		0x02,       // payload length = 2
		0x03, 0xE8, // close code 1000 in network byte order
	}
	_, err := ws.writer.Write(responseFrame)
	if err != nil {
		return err
	}
	return ws.writer.Flush()
}
```
Let's walk through what we've done:
- We close the connection with the client if the request payload is not masked.
- We guard against denial of service (memory exhaustion) by checking the incoming payload length against MAX_PAYLOAD_LENGTH before allocating a buffer for it.
- Reading the masking key is unchanged, apart from switching to io.ReadFull.
- We now use actualPayloadLength instead of payloadLength to size and iterate over the request payload. In the previous version, using payloadLength was valid because we only handled fragments of 125 or less. Since we now handle fragments of 126 and 127, we use the extended (actual) payload length, which is stored in the 2 bytes (fragment 126) or 8 bytes (fragment 127) that follow the second header byte, right before the masking key.
- We now build the response frame according to its payload-length fragment. Don't forget that the RFC requires the extended payload length to be in big-endian order: binary.BigEndian.PutUint16 and binary.BigEndian.PutUint64 write the actual payload length as big-endian bytes, following Section 5.2.
- We close the connection properly with the close function if the request is not masked. Remember that the server MUST NOT mask its response.
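The three if blocks in writeTextFrameResponse differ only in how the header is built. As a sketch (the helper name encodeFrameHeader is hypothetical, not something the article defines), the same Section 5.2 rules can be collected into one function that returns an unmasked server header for a given opcode and payload size:

```go
package main

import (
	"encoding/binary"
	"fmt"
)

// encodeFrameHeader (hypothetical helper) builds an unmasked server frame
// header with FIN=1, choosing the shortest payload-length encoding.
// It assumes payloadLength >= 0.
func encodeFrameHeader(opcode byte, payloadLength int) []byte {
	header := []byte{0x80 | opcode, 0} // FIN bit + opcode; mask bit stays 0
	switch {
	case payloadLength <= 125:
		header[1] = byte(payloadLength) // fits in the 7-bit fragment
	case payloadLength <= 0xFFFF:
		header[1] = 126
		ext := make([]byte, 2)
		binary.BigEndian.PutUint16(ext, uint16(payloadLength))
		header = append(header, ext...)
	default:
		header[1] = 127
		ext := make([]byte, 8)
		binary.BigEndian.PutUint64(ext, uint64(payloadLength))
		header = append(header, ext...)
	}
	return header
}

func main() {
	fmt.Printf("% x\n", encodeFrameHeader(0x1, 5))     // 81 05
	fmt.Printf("% x\n", encodeFrameHeader(0x1, 256))   // 81 7e 01 00
	fmt.Printf("% x\n", encodeFrameHeader(0x1, 70000)) // 81 7f 00 00 00 00 00 01 11 70
}
```

With a helper like this, the echo could become append(encodeFrameHeader(frame.opcode, len(responseByte)), responseByte...) followed by a single Write, which keeps the three length ranges in one place.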
Complete Code:
```go
package main

import (
	"bufio"
	"crypto/sha1"
	b64 "encoding/base64"
	"encoding/binary"
	"errors"
	"io"
	"log"
	"net"
	"strings"
)

const (
	secWsKey             = "Sec-WebSocket-Key"
	connHeaderKey        = "Connection"
	connHeaderVal        = "Upgrade"
	upgradeHeaderKey     = "Upgrade"
	upgradeConnHeaderVal = "websocket"
)

const (
	MAX_PAYLOAD_LENGTH = 262144
)

func main() {
	port := ":8083"
	listener, err := net.Listen("tcp", port)
	if err != nil {
		log.Fatal(err)
	}
	log.Println("running on ", port)
	for {
		conn, err := listener.Accept()
		if err != nil {
			log.Fatalln(err)
		}
		go func() {
			// make sure the socket is released even if the handshake fails
			defer conn.Close()
			if err := handshake(conn); err != nil {
				log.Println(err)
			}
		}()
	}
}

func handshake(conn net.Conn) error {
	var secWsAccept string
	reader := bufio.NewReader(conn)
	for {
		header, err := reader.ReadString('\n')
		if err != nil {
			// includes io.EOF: the client left before finishing its headers,
			// and looping on would spin forever on empty reads
			return err
		}
		if header == "\r\n" || header == "\n" {
			break
		}
		secWsAcceptVal, err := readHTTPUpgradeHeaderRequest(header)
		if err != nil {
			return err
		}
		if secWsAcceptVal != "" {
			secWsAccept = secWsAcceptVal
		}
	}
	writer := bufio.NewWriter(conn)
	upgradeResp := []string{
		"HTTP/1.1 101 Web Socket Protocol Handshake",
		"Server: go/echoserver",
		"Upgrade: WebSocket",
		"Connection: Upgrade",
		"Sec-WebSocket-Accept: " + secWsAccept,
		"", // required for extra CRLF
		"", // required for extra CRLF
	}
	_, err := writer.Write([]byte(strings.Join(upgradeResp, "\r\n")))
	if err != nil {
		return err
	}
	err = writer.Flush()
	if err != nil {
		return err
	}
	ws := ws{
		conn:   conn,
		reader: reader,
		writer: writer,
	}
	return ws.handleRequest()
}

// read HTTP upgrade request, returns Sec-WebSocket-Accept header
// value if header is Sec-WebSocket-Key. Otherwise, checking
// other upgrade header defined at https://datatracker.ietf.org/doc/html/rfc6455#autoid-4
func readHTTPUpgradeHeaderRequest(header string) (string, error) {
	var secWsAccept string
	headerKeys := strings.Split(header, ":")
	headerKey := strings.TrimSpace(headerKeys[0])
	switch {
	case headerKey == upgradeHeaderKey:
		uUpgradeVal := strings.TrimSpace(headerKeys[1])
		// the value is case-insensitive ("websocket" vs "WebSocket")
		if !strings.EqualFold(uUpgradeVal, upgradeConnHeaderVal) {
			return "", errors.New("upgrade header value is not websocket")
		}
		return "", nil
	case strings.Contains(header, connHeaderKey):
		cConnVal := strings.TrimSpace(headerKeys[1])
		// clients may send a token list such as "keep-alive, Upgrade"
		if !strings.Contains(strings.ToLower(cConnVal), strings.ToLower(connHeaderVal)) {
			return "", errors.New("connection header value is not upgrade")
		}
		return "", nil
	case strings.Contains(header, secWsKey):
		sSecWsVal := strings.TrimSpace(headerKeys[1])
		secWsAccept = sSecWsVal + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
		sha := sha1.New()
		sha.Write([]byte(secWsAccept))
		encSecWsAccept := sha.Sum(nil)
		secWsAccept = b64.StdEncoding.EncodeToString(encSecWsAccept)
		return secWsAccept, nil
	}
	return "", nil
}

type ws struct {
	conn   net.Conn
	reader *bufio.Reader
	writer *bufio.Writer
}

func (ws *ws) handleRequest() error {
	defer ws.conn.Close()
	for {
		frame, err := ws.readRequestRFCCompliant(2)
		if err != nil {
			return err
		}
		switch frame.opcode {
		case 1:
			err := ws.writeTextFrameResponseRFCCompliant(frame)
			if err != nil {
				return err
			}
		case 2:
			log.Println("binary frame")
			// consume the frame so the next read starts at a frame header
			if err := ws.discardPayload(frame); err != nil {
				return err
			}
		case 8:
			return ws.close(frame)
		case 9:
			if err := ws.pong(frame); err != nil {
				return err
			}
		default:
			// unsupported opcode: stop instead of parsing its payload as a header
			return ws.close(frame)
		}
	}
}

func (ws *ws) readRequestRFCCompliant(requestSize int) (frame, error) {
	data := make([]byte, requestSize)
	_, err := io.ReadFull(ws.reader, data)
	if err != nil {
		return frame{}, err
	}
	var actualPayloadLength int64
	opcode := data[0] & 0xf
	payloadIndicator := data[1] & 0x7f
	isMasked := data[1]&0x80 != 0
	frame := frame{}
	// read payload length frame
	switch {
	case payloadIndicator < 126:
		actualPayloadLength = int64(payloadIndicator)
	case payloadIndicator == 126:
		// actual payload length is the next 2 bytes
		actualPayloadByte := make([]byte, 2)
		_, err := io.ReadFull(ws.reader, actualPayloadByte)
		if err != nil {
			return frame, err
		}
		actualPayloadLength = int64(binary.BigEndian.Uint16(actualPayloadByte))
		frame.bigEndianPayloadFormat = append(frame.bigEndianPayloadFormat, actualPayloadByte...)
	case payloadIndicator == 127:
		// actual payload length is the next 8 bytes
		actualPayloadByte := make([]byte, 8)
		_, err := io.ReadFull(ws.reader, actualPayloadByte)
		if err != nil {
			return frame, err
		}
		// a set most significant bit (forbidden by Section 5.2) becomes a
		// negative int64 here, which the payload readers below reject
		actualPayloadLength = int64(binary.BigEndian.Uint64(actualPayloadByte))
		frame.bigEndianPayloadFormat = actualPayloadByte
	}
	frame.opcode = opcode
	frame.payloadIndicator = int64(payloadIndicator)
	frame.actualPayloadLength = actualPayloadLength
	frame.isMasked = isMasked
	return frame, nil
}

func (ws *ws) writeTextFrameResponseRFCCompliant(frame frame) error {
	if !frame.isMasked {
		return ws.close(frame)
	}
	if !(frame.actualPayloadLength >= 0 && frame.actualPayloadLength <= MAX_PAYLOAD_LENGTH) {
		// RFC 6455 Section 7.4.1 defines 1009 (message too big) for this case;
		// close() sends 1000 for simplicity
		if err := ws.close(frame); err != nil {
			return err
		}
		return errors.New("payload length exceeds MAX_PAYLOAD_LENGTH")
	}
	maskingKey := make([]byte, 4)
	_, err := io.ReadFull(ws.reader, maskingKey)
	if err != nil {
		return err
	}
	requestData := make([]byte, frame.actualPayloadLength)
	_, err = io.ReadFull(ws.reader, requestData)
	if err != nil {
		return err
	}
	// unmasking request payload, refer to https://datatracker.ietf.org/doc/html/rfc6455#autoid-24
	for i := 0; i < len(requestData); i++ {
		requestData[i] = requestData[i] ^ maskingKey[i%4]
	}
	responseByte := []byte(requestData)
	if len(responseByte) <= 125 {
		responseFrame := make([]byte, 2)
		responseFrame[0] = 0x80 | frame.opcode
		responseFrame[1] = byte(len(responseByte))
		responseFrame = append(responseFrame, responseByte...)
		_, err = ws.writer.Write(responseFrame)
		if err != nil {
			return err
		}
	}
	if len(responseByte) >= 126 && len(responseByte) <= 65535 {
		responseFrame := make([]byte, 2)
		responseFrame[0] = 0x80 | frame.opcode
		responseFrame[1] = 0x7e
		actualPayloadFrameByte := make([]byte, 2)
		binary.BigEndian.PutUint16(actualPayloadFrameByte, uint16(len(responseByte)))
		responseFrame = append(responseFrame, actualPayloadFrameByte...)
		responseFrame = append(responseFrame, responseByte...)
		_, err = ws.writer.Write(responseFrame)
		if err != nil {
			return err
		}
	}
	if len(responseByte) >= 65536 {
		responseFrame := make([]byte, 2)
		responseFrame[0] = 0x80 | frame.opcode
		responseFrame[1] = 0x7f
		actualPayloadFrameByte := make([]byte, 8)
		binary.BigEndian.PutUint64(actualPayloadFrameByte, uint64(len(responseByte)))
		responseFrame = append(responseFrame, actualPayloadFrameByte...)
		responseFrame = append(responseFrame, responseByte...)
		_, err = ws.writer.Write(responseFrame)
		if err != nil {
			return err
		}
	}
	err = ws.writer.Flush()
	if err != nil {
		return err
	}
	return nil
}

// discardPayload skips the masking key and payload of a frame the server
// does not process, keeping the reader aligned on frame boundaries.
func (ws *ws) discardPayload(frame frame) error {
	if frame.actualPayloadLength < 0 || frame.actualPayloadLength > MAX_PAYLOAD_LENGTH {
		return errors.New("payload length exceeds MAX_PAYLOAD_LENGTH")
	}
	n := frame.actualPayloadLength
	if frame.isMasked {
		n += 4 // masking key
	}
	_, err := io.CopyN(io.Discard, ws.reader, n)
	return err
}

func (ws *ws) close(frame frame) error {
	if !frame.isMasked {
		// RFC 6455 Section 5.1: an unmasked client frame means the server
		// MUST close the connection
		ws.conn.Close()
		return errors.New("frame MUST be masked")
	}
	responseFrame := []byte{
		0x88,       // FIN + Close opcode
		0x02,       // payload length = 2
		0x03, 0xE8, // close code 1000 in network byte order
	}
	_, err := ws.writer.Write(responseFrame)
	if err != nil {
		return err
	}
	return ws.writer.Flush()
}

func (ws *ws) pong(frame frame) error {
	if !frame.isMasked {
		return ws.close(frame)
	}
	// control frames carry at most 125 bytes of payload (RFC 6455 Section 5.5)
	if frame.actualPayloadLength < 0 || frame.actualPayloadLength > 125 {
		if err := ws.close(frame); err != nil {
			return err
		}
		return errors.New("ping payload exceeds 125 bytes")
	}
	maskingKey := make([]byte, 4)
	_, err := io.ReadFull(ws.reader, maskingKey)
	if err != nil {
		return err
	}
	// read and unmask the ping payload: the pong must echo it (Section 5.5.3),
	// and leaving it unread would corrupt the next header read
	payload := make([]byte, frame.actualPayloadLength)
	_, err = io.ReadFull(ws.reader, payload)
	if err != nil {
		return err
	}
	for i := 0; i < len(payload); i++ {
		payload[i] = payload[i] ^ maskingKey[i%4]
	}
	responseFrame := make([]byte, 2)
	responseFrame[0] = 0x8a // FIN + Pong opcode
	responseFrame[1] = byte(len(payload))
	responseFrame = append(responseFrame, payload...)
	_, err = ws.writer.Write(responseFrame)
	if err != nil {
		return err
	}
	return ws.writer.Flush()
}

type frame struct {
	opcode                 uint8
	payloadIndicator       int64
	isMasked               bool
	actualPayloadLength    int64
	bigEndianPayloadFormat []byte
}
```
Code Refinement Spotlight
This section presents the refined WebSocket server implementation and highlights each improvement compared to the previous version. Before discussing the refinements, here is the original (unrefined) version of the code:
```go
package main

import (
	"bufio"
	"crypto/sha1"
	b64 "encoding/base64"
	"errors"
	"io"
	"log"
	"net"
	"strings"
)

const (
	secWsKey             = "Sec-WebSocket-Key"
	connHeaderKey        = "Connection"
	connHeaderVal        = "Upgrade"
	upgradeHeaderKey     = "Upgrade"
	upgradeConnHeaderVal = "websocket"
)

func main() {
	port := ":8083"
	listener, err := net.Listen("tcp", port)
	if err != nil {
		log.Fatal(err)
	}
	log.Println("running on ", port)
	for {
		conn, err := listener.Accept()
		if err != nil {
			log.Fatalln(err)
		}
		go handshake(conn)
	}
}

func handshake(conn net.Conn) {
	var secWsAccept string
	reader := bufio.NewReader(conn)
	for {
		header, err := reader.ReadString('\n')
		if err != nil && !errors.Is(err, io.EOF) {
			log.Fatal(err)
		}
		if header == "\r\n" || header == "\n" {
			break
		}
		secWsAcceptVal, err := readHTTPUpgradeHeaderRequest(header)
		if err != nil {
			log.Fatal(err)
		}
		if secWsAcceptVal != "" {
			secWsAccept = secWsAcceptVal
		}
	}
	writer := bufio.NewWriter(conn)
	upgradeResp := []string{
		"HTTP/1.1 101 Web Socket Protocol Handshake",
		"Server: go/echoserver",
		"Upgrade: WebSocket",
		"Connection: Upgrade",
		"Sec-WebSocket-Accept: " + secWsAccept,
		"", // required for extra CRLF
		"", // required for extra CRLF
	}
	_, err := writer.Write([]byte(strings.Join(upgradeResp, "\r\n")))
	if err != nil {
		log.Println(err)
		return
	}
	err = writer.Flush()
	if err != nil {
		log.Println(err)
		return
	}
	ws := ws{
		conn:   conn,
		reader: reader,
		writer: writer,
	}
	ws.handleRequest()
}

// read HTTP upgrade request, returns Sec-WebSocket-Accept header
// value if header is Sec-WebSocket-Accept. Otherwise, checking
// other upgrade header defined at https://datatracker.ietf.org/doc/html/rfc6455#autoid-4
func readHTTPUpgradeHeaderRequest(header string) (string, error) {
	var secWsAccept string
	headerKeys := strings.Split(header, ":")
	headerKey := strings.TrimSpace(headerKeys[0])
	switch {
	case headerKey == upgradeHeaderKey:
		uUpgradeVal := strings.TrimSpace(headerKeys[1])
		if uUpgradeVal != upgradeConnHeaderVal {
			return "", errors.New("upgrade header value is not websocket")
		}
		return "", nil
	case strings.Contains(header, connHeaderKey):
		cConnVal := strings.TrimSpace(headerKeys[1])
		if cConnVal != connHeaderVal {
			return "", errors.New("connection header value is not upgrade")
		}
		return "", nil
	case strings.Contains(header, secWsKey):
		sSecWsVal := strings.TrimSpace(headerKeys[1])
		secWsAccept = sSecWsVal + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
		sha := sha1.New()
		sha.Write([]byte(secWsAccept))
		encSecWsAccept := sha.Sum(nil)
		secWsAccept = b64.StdEncoding.EncodeToString(encSecWsAccept)
		return secWsAccept, nil
	}
	return "", nil
}

type ws struct {
	conn   net.Conn
	reader *bufio.Reader
	writer *bufio.Writer
}

func (ws *ws) handleRequest() {
	defer ws.conn.Close()
	for {
		frame, err := ws.readRequest(2)
		if err != nil {
			log.Fatal(err)
		}
		switch frame.opcode {
		case 1:
			err := ws.writeTextFrameResponse(frame)
			if err != nil {
				log.Fatal(err)
			}
		case 2:
			log.Println("binary frame")
		case 8:
			log.Println("disconnected from client")
			return
		case 9:
			frame.opcode = 10
		default:
			log.Println("invalid opcode")
		}
	}
}

func (ws *ws) readRequest(requestSize int) (frame, error) {
	data := make([]byte, requestSize)
	_, err := ws.reader.Read(data)
	if err != nil {
		return frame{}, err
	}
	opcode := data[0] & 0xf
	payloadLength := data[1] & 0x7f
	return frame{
		opcode:        opcode,
		payloadLength: int(payloadLength),
	}, nil
}

// writeTextFrameResponse echoes server response as same as
// client request. This works normally when client
// request length is < 126
func (ws *ws) writeTextFrameResponse(frame frame) error {
	// read masking key where masking key located
	// 3rd - 6th byte of request, first 2 already read
	// to retrieve opcode and payload length.
	maskingKey := make([]byte, 4)
	_, err := ws.reader.Read(maskingKey)
	if err != nil {
		return err
	}
	requestData := make([]byte, frame.payloadLength)
	_, err = ws.reader.Read(requestData)
	if err != nil {
		return err
	}
	// unmasking request payload, refer to https://datatracker.ietf.org/doc/html/rfc6455#autoid-24
	for i := 0; i < len(requestData); i++ {
		requestData[i] = requestData[i] ^ maskingKey[i%4]
	}
	responseByte := []byte(requestData)
	responseFrame := make([]byte, 2)
	responseFrame[0] = 0x80 | frame.opcode
	responseFrame[1] = byte(len(responseByte))
	responseFrame = append(responseFrame, responseByte...)
	_, err = ws.writer.Write(responseFrame)
	if err != nil {
		return err
	}
	err = ws.writer.Flush()
	if err != nil {
		return err
	}
	return nil
}

type frame struct {
	opcode                 uint8
	payloadLength          int
	actualPayloadLength    int
	bigEndianPayloadFormat []byte
}
```
Let's review what we changed compared to the previous code shown above:
| Aspect | Before Refinement | After Refinement | Why It Matters | Changes in This Article |
|---|---|---|---|---|
| Request frame handling | Loosely handles opcodes 1, 2, 8, 9 | Implements RFC logic for close, ping/pong, and invalid opcodes | No silent features or failures; proper control-frame handling | Opcode checking in handleRequest() |
| Masking | Server ignores unmasked client frames | Server closes the connection with the client immediately | RFC compliance | Mask validation in writeTextFrameResponse() |
| Ping/pong | Server only sets the frame opcode without sending a pong response | Server sends a pong response frame | Required for heartbeat/liveness | Opcode checking in handleRequest() and the pong() definition |
| Close frame | Server only logs and returns when it receives a close opcode | Server sends a proper close frame (FIN=1, opcode=8, close code=1000) | RFC-compliant closing handshake | Opcode checking in handleRequest() and the close() definition |
| Payload-length parsing | Only handles fragments <= 125 when reading the request frame | Adds proper handling for 126 and 127 (next 2 bytes and next 8 bytes) | Prevents truncated reads and incorrect frame interpretation | Extended payload length parsing in readRequest() |
| Payload length safety | No limit | Applies MAX_PAYLOAD_LENGTH | Prevents DoS/memory abuse | Actual payload length validation in writeTextFrameResponse() |
| Extended payload length | Not read from the request frame | Reads and stores the next 2 or 8 bytes in big-endian order | Essential for reconstructing the frame when echoing the message to the client | Extended payload length reading in readRequest() |
| Response frame encoding | Reuses the request's payload length when building the response frame | Computes the payload length from the response | Avoids malformed frames | Response frame building in writeTextFrameResponse() |
| Reads | Uses Read | Uses io.ReadFull | Prevents partial-frame bugs | Reading the header and extended length in readRequest(), and the masking key and payload in writeTextFrameResponse() |
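The close frame in the table is hard-coded with status 1000. RFC 6455 Section 7.4.1 also defines codes such as 1002 (protocol error) and 1009 (message too big), which describe the unmasked-frame and oversized-payload cases more precisely. Here is a sketch of a builder that takes the status code as a parameter; the name closeFrame is hypothetical and not part of the server above.

```go
package main

import (
	"encoding/binary"
	"fmt"
)

// closeFrame (hypothetical helper) builds an unmasked close frame whose
// 2-byte payload is the status code in network byte order.
func closeFrame(code uint16) []byte {
	frame := []byte{0x88, 0x02} // FIN + close opcode, payload length 2
	status := make([]byte, 2)
	binary.BigEndian.PutUint16(status, code)
	return append(frame, status...)
}

func main() {
	fmt.Printf("% x\n", closeFrame(1000)) // 88 02 03 e8
	fmt.Printf("% x\n", closeFrame(1009)) // 88 02 03 f1
}
```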
Closing Thought on Extended Payload Length
Extended payload length handling is a commonly mishandled part of WebSocket framing, and getting it wrong leads to unexpected behavior or silent errors. This article fixed the previous article's omission of the extended payload length. The corrected implementation parses the 7-bit payload indicator in the second byte of the WebSocket frame, then reads the next 2 bytes (126) or 8 bytes (127) to get the extended payload length from the request frame.
Along the way, we also covered big-endian decoding, safe reads with io.ReadFull, strict mask enforcement, and properly closing the connection on invalid or malformed request frames.
With these fundamentals in place, other WebSocket concepts such as fragmentation, control-frame handling, and lifecycle management become easier to tackle. And remember: just because it works does not mean it is correct.