In this post I will write a simple memcache-style (in-memory key-value) server with tokio that speaks the Redis protocol. It supports only 2 Redis commands:
get <key>
set <key> <value>
The goal is that, once the server is running, we can connect to it with the standard redis-cli and use the get and set commands.
Basic code structure
We use tokio to write this high-performance server:
In Cargo.toml:
```toml
[dependencies]
tokio = { version = "0.2", features = ["full"] }
```
In main.rs:
```rust
use std::collections::HashMap;
use std::error::Error;
use std::sync::Arc;
use tokio::net::{TcpListener, TcpStream};
use tokio::prelude::*;
use tokio::sync::Mutex;

#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
    let addr = "127.0.0.1:7777";
    let mut listener = TcpListener::bind(addr).await?;
    println!("Listen on {}", addr);

    // Shared key-value store. The explicit types are needed in this skeleton
    // because nothing uses the map yet, so the compiler cannot infer them.
    let dict = Arc::new(Mutex::new(HashMap::<String, String>::new()));

    loop {
        let (mut sock, _) = listener.accept().await?;
        let dict = dict.clone();
        tokio::spawn(async move {
            // todo
        });
    }
}
```
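The dict variable will hold the key-value pairs. We wrap it in a Mutex and an Arc because it is shared by all connection tasks, which tokio may run on multiple threads: the Arc lets every task hold a handle to the same map, and the Mutex ensures that only one task accesses it at a time.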
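A minimal sketch of why the map needs both wrappers, assuming plain OS threads and std::sync::Mutex in place of tokio tasks and tokio::sync::Mutex (a deliberate swap so the example runs without tokio). The function name concurrent_inserts is hypothetical. Each thread gets its own clone of the Arc, and the Mutex makes sure only one of them touches the map at a time:

```rust
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::thread;

/// Hypothetical helper: spawn `n` threads that each insert one key into a
/// shared map, then return how many entries the map ends up with.
fn concurrent_inserts(n: usize) -> usize {
    // Arc: shared ownership across threads; Mutex: exclusive access
    let dict: Arc<Mutex<HashMap<String, String>>> = Arc::new(Mutex::new(HashMap::new()));
    let mut handles = Vec::new();
    for i in 0..n {
        // Clone the Arc (a cheap pointer copy), not the map itself
        let dict = Arc::clone(&dict);
        handles.push(thread::spawn(move || {
            // Only one thread at a time can hold this guard
            let mut map = dict.lock().unwrap();
            map.insert(format!("key{}", i), format!("value{}", i));
        }));
    }
    for h in handles {
        h.join().unwrap();
    }
    let map = dict.lock().unwrap();
    map.len()
}
```

In the server itself tokio::sync::Mutex is the better fit: handle_get and handle_set keep the lock guard alive across write_all(...).await, and a std::sync::MutexGuard is not Send, so a task holding one across an .await could not be passed to tokio::spawn.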
Redis Protocol
Redis uses \r\n (CRLF) to terminate each line of its protocol, known as RESP. Explaining the protocol in detail would need another post, so here we only introduce the basics:
For Simple Strings the first byte of the reply is "+"
For Errors the first byte of the reply is "-"
For Integers the first byte of the reply is ":"
For Bulk Strings the first byte of the reply is "$"
For Arrays the first byte of the reply is "*"
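To make the five prefixes concrete, here is a sketch of a reply encoder. The Reply enum and encode function are hypothetical names for illustration, not part of any Redis library; the null bulk string ($-1) is included because the GET handler later in this post relies on it:

```rust
/// Hypothetical model of the five RESP value types listed above.
enum Reply {
    Simple(String),       // +OK\r\n
    Error(String),        // -ERR message\r\n
    Integer(i64),         // :42\r\n
    Bulk(Option<String>), // $5\r\nhello\r\n, or $-1\r\n for a null bulk string
    Array(Vec<Reply>),    // *<count>\r\n followed by each encoded element
}

fn encode(reply: &Reply) -> String {
    match reply {
        Reply::Simple(s) => format!("+{}\r\n", s),
        Reply::Error(e) => format!("-{}\r\n", e),
        Reply::Integer(n) => format!(":{}\r\n", n),
        // The bulk length counts bytes, not characters
        Reply::Bulk(Some(s)) => format!("${}\r\n{}\r\n", s.len(), s),
        Reply::Bulk(None) => "$-1\r\n".to_string(),
        Reply::Array(items) => {
            let mut out = format!("*{}\r\n", items.len());
            for item in items {
                out.push_str(&encode(item));
            }
            out
        }
    }
}
```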
For example, for the get <key> command with the key "hello", redis-cli sends the following, where every line ends with \r\n:
```
*2
$3
get
$5
hello
```
The leading *2 indicates an array of 2 elements. Each element is a bulk string: $3 says the first one (get) is 3 bytes long, and $5 says the second one (hello) is 5 bytes long.
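The same framing can be produced programmatically. This sketch, with the hypothetical helper encode_command, builds the bytes we expect a client such as redis-cli to send for a command. Because each argument carries its byte length, a value with a space in it, like "world overwrite", travels as a single argument without any escaping:

```rust
/// Hypothetical helper: encode a command the way a Redis client sends it,
/// as a RESP array whose elements are all bulk strings.
fn encode_command(args: &[&str]) -> Vec<u8> {
    // Array header: number of elements
    let mut out = format!("*{}\r\n", args.len()).into_bytes();
    for arg in args {
        // Each element: "$<byte length>\r\n<bytes>\r\n"
        out.extend_from_slice(format!("${}\r\n", arg.len()).as_bytes());
        out.extend_from_slice(arg.as_bytes());
        out.extend_from_slice(b"\r\n");
    }
    out
}
```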
Let's write two helper functions, get_next_len to read a length integer and get_next_string to read the string that follows it, on top of two lower-level readers, read_till_crlf and read_nbytes:
```rust
/// Read up to and including the next "\n", skipping the first `skip` bytes
/// (the type prefix such as '*' or '$'), and return the rest without "\r\n".
async fn read_till_crlf(stream: &mut TcpStream, skip: u8) -> Vec<u8> {
    let mut ret: Vec<u8> = vec![];
    let mut skip_num = skip;
    loop {
        let mut buf = [0; 1];
        // unwrap panics (ending this task) if the client disconnects
        stream.read_exact(&mut buf).await.unwrap();
        // LF's ASCII code is 10
        if skip_num == 0 && buf[0] == 10 {
            break;
        }
        if skip_num > 0 {
            skip_num -= 1;
        } else {
            ret.push(buf[0]);
        }
    }
    // pop the trailing CR
    ret.pop();
    ret
}

/// Read exactly `nbytes` bytes from the stream.
async fn read_nbytes(stream: &mut TcpStream, nbytes: usize) -> Vec<u8> {
    let mut ret: Vec<u8> = vec![0; nbytes];
    stream.read_exact(&mut ret).await.unwrap();
    ret
}

/// Parse a length line such as "*2\r\n" or "$3\r\n" into a number.
async fn get_next_len(stream: &mut TcpStream) -> usize {
    let vlen = read_till_crlf(stream, 1).await;
    let slen = String::from_utf8(vlen).unwrap();
    let len: usize = slen.parse().unwrap();
    len
}

/// Read a bulk string such as "$5\r\nhello\r\n" and return "hello".
async fn get_next_string(stream: &mut TcpStream) -> String {
    let len = get_next_len(stream).await;
    let vs = read_nbytes(stream, len).await;
    // consume the trailing \r\n
    let _ = read_nbytes(stream, 2).await;
    // build string and return
    let s = String::from_utf8(vs).unwrap();
    s
}
```
The get_next_len function returns the number encoded in a line such as "*2\r\n" or "$3\r\n", and get_next_string returns the string that follows such a length line: for "$5\r\nhello\r\n" it returns "hello".
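For comparison, here is a synchronous, std-only sketch of the same parsing logic that works on anything implementing std::io::BufRead (a byte slice, for instance) instead of a tokio TcpStream. The names read_len, read_bulk and read_command are hypothetical, and unlike the helpers above they return io::Result instead of calling unwrap:

```rust
use std::io::{self, BufRead, Read};

/// Read one "<prefix><number>\r\n" line (e.g. "*2\r\n" or "$5\r\n") and
/// return the number; a sync stand-in for get_next_len.
fn read_len<R: BufRead>(r: &mut R, prefix: u8) -> io::Result<usize> {
    let mut line = Vec::new();
    r.read_until(b'\n', &mut line)?;
    // Expect at least: prefix, one digit, "\r\n"
    if line.len() < 4 || line[0] != prefix || !line.ends_with(b"\r\n") {
        return Err(io::Error::new(io::ErrorKind::InvalidData, "bad length line"));
    }
    let digits = std::str::from_utf8(&line[1..line.len() - 2])
        .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
    digits
        .parse::<usize>()
        .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
}

/// Read "$<n>\r\n<n bytes>\r\n" and return the n bytes as a String;
/// a sync stand-in for get_next_string.
fn read_bulk<R: BufRead>(r: &mut R) -> io::Result<String> {
    let len = read_len(r, b'$')?;
    // Read the payload plus its trailing "\r\n" in one go, then drop the CRLF
    let mut buf = vec![0u8; len + 2];
    r.read_exact(&mut buf)?;
    buf.truncate(len);
    String::from_utf8(buf).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
}

/// Read a whole command: an array header followed by that many bulk strings.
fn read_command<R: BufRead>(r: &mut R) -> io::Result<Vec<String>> {
    let n = read_len(r, b'*')?;
    let mut args = Vec::with_capacity(n);
    for _ in 0..n {
        args.push(read_bulk(r)?);
    }
    Ok(args)
}
```

With io::Result, a client that disconnects mid-command surfaces as an error the caller can handle (for example by closing the connection) rather than as a panic inside the task.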
Handle unknown command and syntax error
Since we only support the get <key> and set <key> <value> commands, let's write two more helper functions that reply with a RESP error:
```rust
async fn handle_unknown(stream: &mut TcpStream) {
    stream.write_all(b"-Unknown command\r\n").await.unwrap();
}

async fn handle_syntax_err(stream: &mut TcpStream) {
    stream.write_all(b"-ERR syntax error\r\n").await.unwrap();
}
```
Handle GET command
```rust
async fn handle_get(stream: &mut TcpStream,
                    dict: &Arc<Mutex<HashMap<String, String>>>) {
    let key = get_next_string(stream).await;
    let map = dict.lock().await;
    let s = match map.get(key.as_str()) {
        Some(v) => {
            format!("${}\r\n{}\r\n", v.len(), v)
        },
        None => {
            "$-1\r\n".to_owned()
        },
    };
    stream.write_all(s.as_bytes()).await.unwrap();
}
```
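First we read the key. If the key is in our dictionary, we return its value as a bulk string ($<length>\r\n<value>\r\n); otherwise we return the null bulk string $-1\r\n, which redis-cli displays as (nil). Like the request, the reply must follow the Redis protocol.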
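The reply logic of handle_get can be pulled out into a pure function, which makes it easy to check without a socket. A sketch, with get_reply as a hypothetical name:

```rust
use std::collections::HashMap;

/// Hypothetical pure version of the reply logic in handle_get:
/// a hit becomes a bulk string, a miss becomes the null bulk string.
fn get_reply(map: &HashMap<String, String>, key: &str) -> String {
    match map.get(key) {
        // The length prefix is the value's length in bytes
        Some(v) => format!("${}\r\n{}\r\n", v.len(), v),
        // redis-cli renders the null bulk string as (nil)
        None => "$-1\r\n".to_string(),
    }
}
```

Because String::len counts bytes, a non-ASCII value such as "héllo" gets a $6 prefix even though it has 5 characters, which is exactly what the protocol expects.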
Handle SET command
```rust
async fn handle_set(stream: &mut TcpStream,
                    dict: &Arc<Mutex<HashMap<String, String>>>) {
    let key = get_next_string(stream).await;
    let val = get_next_string(stream).await;
    let mut map = dict.lock().await;
    map.insert(key, val);
    stream.write_all(b"+OK\r\n").await.unwrap();
}
```
For the set command, we read the key and then the value, insert them into our dictionary, and reply with the simple string +OK. We don't check whether the key already exists: if it does, its value is simply overwritten.
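Likewise for SET, here is a sketch of the store-and-reply step as a pure function (set_and_reply is a hypothetical name). HashMap::insert returns the value it replaced, if any; the post's server ignores it, and it is returned here only to make the overwrite visible:

```rust
use std::collections::HashMap;

/// Hypothetical pure version of handle_set: insert (or overwrite) the key
/// and return the simple-string reply plus the value that was replaced.
fn set_and_reply(
    map: &mut HashMap<String, String>,
    key: &str,
    val: &str,
) -> (&'static str, Option<String>) {
    // insert overwrites an existing key and hands back its old value
    let previous = map.insert(key.to_string(), val.to_string());
    ("+OK\r\n", previous)
}
```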
Put it together
Full code:
```rust
use std::collections::HashMap;
use std::error::Error;
use std::sync::Arc;
use tokio::net::{TcpListener, TcpStream};
use tokio::prelude::*;
use tokio::sync::Mutex;

#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
    let addr = "127.0.0.1:7777";
    let mut listener = TcpListener::bind(addr).await?;
    println!("Listen on {}", addr);
    let dict = Arc::new(Mutex::new(HashMap::new()));
    loop {
        let (mut sock, _) = listener.accept().await?;
        let dict = dict.clone();
        tokio::spawn(async move {
            // redis-cli sends every command of a session over the same
            // connection, so keep serving until the client disconnects
            loop {
                // peek returns Ok(0) once the client has closed the connection
                let mut probe = [0u8; 1];
                match sock.peek(&mut probe).await {
                    Ok(0) | Err(_) => break,
                    Ok(_) => {}
                }
                // get arg array length like *2, *3
                let arg_len = get_next_len(&mut sock).await;
                let cmd = get_next_string(&mut sock).await.to_lowercase();
                if cmd == "get" && arg_len == 2 {
                    handle_get(&mut sock, &dict).await;
                } else if cmd == "set" && arg_len == 3 {
                    handle_set(&mut sock, &dict).await;
                } else {
                    // consume the remaining arguments so the next command
                    // is parsed from the right position in the stream
                    for _ in 1..arg_len {
                        let _ = get_next_string(&mut sock).await;
                    }
                    if cmd == "get" || cmd == "set" {
                        handle_syntax_err(&mut sock).await;
                    } else {
                        handle_unknown(&mut sock).await;
                    }
                }
            }
        });
    }
}

async fn read_till_crlf(stream: &mut TcpStream, skip: u8) -> Vec<u8> {
    let mut ret: Vec<u8> = vec![];
    let mut skip_num = skip;
    loop {
        let mut buf = [0; 1];
        stream.read_exact(&mut buf).await.unwrap();
        // LF's ASCII code is 10
        if skip_num == 0 && buf[0] == 10 {
            break;
        }
        if skip_num > 0 {
            skip_num -= 1;
        } else {
            ret.push(buf[0]);
        }
    }
    // pop the trailing CR
    ret.pop();
    ret
}

async fn read_nbytes(stream: &mut TcpStream, nbytes: usize) -> Vec<u8> {
    let mut ret: Vec<u8> = vec![0; nbytes];
    stream.read_exact(&mut ret).await.unwrap();
    ret
}

async fn get_next_len(stream: &mut TcpStream) -> usize {
    let vlen = read_till_crlf(stream, 1).await;
    let slen = String::from_utf8(vlen).unwrap();
    let len: usize = slen.parse().unwrap();
    len
}

async fn get_next_string(stream: &mut TcpStream) -> String {
    let len = get_next_len(stream).await;
    let vs = read_nbytes(stream, len).await;
    // consume the trailing \r\n
    let _ = read_nbytes(stream, 2).await;
    // build string and return
    let s = String::from_utf8(vs).unwrap();
    s
}

async fn handle_get(stream: &mut TcpStream,
                    dict: &Arc<Mutex<HashMap<String, String>>>) {
    let key = get_next_string(stream).await;
    let map = dict.lock().await;
    let s = match map.get(key.as_str()) {
        Some(v) => {
            format!("${}\r\n{}\r\n", v.len(), v)
        },
        None => {
            "$-1\r\n".to_owned()
        },
    };
    stream.write_all(s.as_bytes()).await.unwrap();
}

async fn handle_set(stream: &mut TcpStream,
                    dict: &Arc<Mutex<HashMap<String, String>>>) {
    let key = get_next_string(stream).await;
    let val = get_next_string(stream).await;
    let mut map = dict.lock().await;
    map.insert(key, val);
    stream.write_all(b"+OK\r\n").await.unwrap();
}

async fn handle_unknown(stream: &mut TcpStream) {
    stream.write_all(b"-Unknown command\r\n").await.unwrap();
}

async fn handle_syntax_err(stream: &mut TcpStream) {
    stream.write_all(b"-ERR syntax error\r\n").await.unwrap();
}
```
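Because redis-cli sends every command of a session over the same connection, a server has to keep reading commands from one stream until the client disconnects, and a command with the wrong number of arguments must be consumed completely so the next one still parses. Here is a synchronous, std-only sketch of that per-connection loop over any std::io::BufRead (an in-memory byte slice standing in for the socket). All names here (serve_connection, read_command, parse_len, read_crlf_line, bad_data) are hypothetical:

```rust
use std::collections::HashMap;
use std::io::{self, BufRead, Read};

fn bad_data() -> io::Error {
    io::Error::new(io::ErrorKind::InvalidData, "malformed RESP")
}

/// Read one line including its trailing "\n" (empty Vec at EOF).
fn read_crlf_line<R: BufRead>(r: &mut R) -> io::Result<Vec<u8>> {
    let mut line = Vec::new();
    r.read_until(b'\n', &mut line)?;
    Ok(line)
}

/// Parse "<prefix><number>\r\n" into the number.
fn parse_len(line: &[u8], prefix: u8) -> io::Result<usize> {
    if line.len() < 4 || line[0] != prefix || !line.ends_with(b"\r\n") {
        return Err(bad_data());
    }
    std::str::from_utf8(&line[1..line.len() - 2])
        .ok()
        .and_then(|s| s.parse::<usize>().ok())
        .ok_or_else(bad_data)
}

/// Read one whole command; Ok(None) means a clean EOF between commands.
fn read_command<R: BufRead>(r: &mut R) -> io::Result<Option<Vec<String>>> {
    let header = read_crlf_line(r)?;
    if header.is_empty() {
        return Ok(None);
    }
    let n = parse_len(&header, b'*')?;
    let mut args = Vec::with_capacity(n);
    for _ in 0..n {
        let len = parse_len(&read_crlf_line(r)?, b'$')?;
        let mut buf = vec![0u8; len + 2]; // payload + trailing "\r\n"
        r.read_exact(&mut buf)?;
        buf.truncate(len);
        args.push(String::from_utf8(buf).map_err(|_| bad_data())?);
    }
    Ok(Some(args))
}

/// Serve every command on one connection until EOF; returns all replies.
fn serve_connection<R: BufRead>(
    r: &mut R,
    dict: &mut HashMap<String, String>,
) -> io::Result<String> {
    let mut out = String::new();
    while let Some(args) = read_command(r)? {
        let cmd = args.first().map(|c| c.to_lowercase()).unwrap_or_default();
        // All arguments are already consumed, so arity errors are safe
        let reply = match (cmd.as_str(), args.len()) {
            ("get", 2) => match dict.get(&args[1]) {
                Some(v) => format!("${}\r\n{}\r\n", v.len(), v),
                None => "$-1\r\n".to_string(),
            },
            ("set", 3) => {
                dict.insert(args[1].clone(), args[2].clone());
                "+OK\r\n".to_string()
            }
            ("get", _) | ("set", _) => "-ERR syntax error\r\n".to_string(),
            _ => "-Unknown command\r\n".to_string(),
        };
        out.push_str(&reply);
    }
    Ok(out)
}
```

Reading the whole argument array before dispatching is a different design from the post's code, which dispatches on the command name and then reads the remaining arguments inside each handler; its advantage is that a command with the wrong number of arguments is always fully consumed.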
Let's run it with cargo run:
```console
$ cargo run
Listen on 127.0.0.1:7777
```
In another terminal window, use redis-cli to connect to it:
```console
$ redis-cli -p 7777
127.0.0.1:7777> get hello
(nil)
127.0.0.1:7777> set hello world
OK
127.0.0.1:7777> get hello
"world"
127.0.0.1:7777> set hello world1 world2 world3
(error) ERR syntax error
127.0.0.1:7777> get hello
"world"
127.0.0.1:7777> set hello "world overwrite"
OK
127.0.0.1:7777> get hello
"world overwrite"
127.0.0.1:7777> command
(error) Unknown command
127.0.0.1:7777>
```
It works! Thanks to tokio, with only around 100 lines of code we now have a multi-threaded async memcache-style server. Better still, because it speaks the Redis protocol, we can use existing Redis client libraries. For example, with Python's redis package (pip install redis), we can get the value associated with the "hello" key we just set:
```python
>>> import redis
>>> r = redis.Redis(port=7777)
>>> r.get("hello")
b'world overwrite'
```