DEV Community

chh

I made a Copilot in Rust 🦀 , here is what I have learned... (as a TypeScript dev)

My article Code Llama as Drop-In Replacement for Copilot Code Completion received lots of positive feedback. Since then, I have made a few other attempts to improve the copilot server.

In terms of performance, I opened a PR in exllamav2 adding a copilot server that uses exllamav2's super fast custom CUDA kernels.

To improve the completion quality, I tried a few other LLMs such as replit-code-v1_5-3b, long_llama_code_7b, CodeShell-7B and stablelm-3b-4e1t.

As a practicing developer, and GPU poor without access to H100 clusters, the way I'm most interested in contributing to the AI Wild West era is by improving the ergonomics of the copilot server.

Hugging Face's candle, a minimalist ML framework for Rust, looks super interesting. So I started to create a minimalist copilot server in Rust 🦀.

Before you continue, note that the exllamav2 version (Python + CUDA) is still much faster than the Rust version. This article is mostly for those who are interested in Rust and want to learn the language by building a fun project.

Essentially, this is a Build Your Own Copilot (in Rust 🦀) tutorial; the code is intended to be educational. If you just want to try the final product, oxpilot:

brew install chenhunghan/homebrew-formulae/oxpilot

and start the copilot server

ox serve

or chat with the LLM

ox hi in Japanese

We will be using axum as the web framework, candle for text inference, clap for CLI argument parsing and tokio as the asynchronous runtime.

Table Of Contents

(Some sections are still WIP)

If you are already familiar with Rust to some degree, for example comfortable with Rust's ownership/borrowing but not with the async world, I suggest jumping to the Asynchronous section.

If you are already familiar with async Rust, you can go directly to Hands-On, which introduces some design patterns you might find useful, or just go to the GitHub oxpilot project where everything is open-sourced.

Please expect some, if not many, human errors. I documented my learning process hoping that it can help someone on the internet who likes to build an exciting project when learning a new language.

Thanks to jihchi for reviewing the draft of this article.

Books and References

This article is self-contained, which means it should have all you need to know to read the source code in oxpilot.

However, it's not possible to cover everything in each section. I try to provide references at the end of each section, and I highly recommend reading The Rust Programming Language if you haven't.

Tracing is batteries-included console.log

console.log is a powerful tool in TypeScript: you can print whatever you want, which makes it super useful for debugging. The Rust equivalent is print!; Rust by Example is an excellent resource if you want to get started with print! quickly.

However, print! locks stdout on every call; when printing a lot, it's better to lock stdout once and write through the lock manually, which is tedious.
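For illustration, here is a minimal std-only sketch of locking stdout once and writing through the lock. The helper write_lines is hypothetical; it accepts any io::Write, so the same function works with a locked stdout or with an in-memory buffer.

```rust
use std::io::{self, Write};

// Hypothetical helper: writes each line through a single writer, so the caller
// decides whether that writer is a locked stdout or an in-memory buffer.
fn write_lines<W: Write>(out: &mut W, lines: &[&str]) -> io::Result<()> {
    for line in lines {
        writeln!(out, "{line}")?;
    }
    Ok(())
}

fn main() -> io::Result<()> {
    // Lock stdout once instead of letting every `println!` lock it again.
    let stdout = io::stdout();
    let mut handle = stdout.lock();
    write_lines(&mut handle, &["Hello!", "Rust!"])?;
    Ok(())
}
```

This manual bookkeeping is exactly the tedium that a logging library like tracing takes off our hands.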

Luckily, we have an alternative: tracing, an awesome project by the Tokio team. As a TypeScript developer, I feel at home using tracing.

use tracing::info;
info!("Hello! Rust!");
info!("Print var: {:?}", var);

What is {:?} ?

You might wonder what is {:?} in the code block.

info!("Print var: {:?}", var);

{:?} prints a value using its Debug formatting, which is handy for printing structs (like an Object in TypeScript). Alternatively, {:#?} pretty-prints the value; think of it like console.log(JSON.stringify(object, null, 2)).
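A quick std-only sketch of both format specifiers. The Chunk struct is hypothetical; #[derive(Debug)] is what makes it printable with {:?} and {:#?}.

```rust
// `#[derive(Debug)]` generates the formatting used by `{:?}` and `{:#?}`.
#[derive(Debug)]
struct Chunk {
    text: String,
    index: u32,
}

fn main() {
    let var = Chunk { text: String::from("hi"), index: 0 };
    // single line, like console.log(JSON.stringify(object))
    println!("{:?}", var);
    // multi-line, like console.log(JSON.stringify(object, null, 2))
    println!("{:#?}", var);
}
```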

Measure performance

Tracing is also awesome for logging performance metrics, for example, if I want to measure how long awesome() took.

use tracing::Instrument;
async fn awesome() {}
awesome().instrument(tracing::info_span!("awesome")).await;

This prints super useful messages that tell us when we started invoking awesome(), at which line, on which thread, and how long the function took to execute.

2023-10-22T09:01:13.128553Z INFO ThreadId(01) awesome src/main.rs:172: enter
2023-10-22T09:01:13.128569Z INFO ThreadId(01) awesome src/main.rs:172: close time.busy=15.3µs time.idle=3.96µs

Feeling Safe

Rust is safe by default, where "safe" usually refers to memory safety. From my experience, though, Rust also makes you feel safe shipping to production... once the code compiles.

If you have ever written JavaScript and then switched to TypeScript, you probably know what I mean by "feeling safe".

TypeScript protects us from TypeError: Cannot read property '' of undefined at compile time; Rust is like TypeScript with an ultra-strict mode turned on, protecting us developers from making mistakes at compile time.
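To make the comparison a bit more concrete: Rust has no undefined, so a value that may be missing is typically an Option<T>, and the compiler won't let you use it without handling the None case. A small sketch with a hypothetical first_char_upper function:

```rust
// Hypothetical function: returns the uppercased first character of `name`, if any.
// `Option<char>` in the signature tells callers the value may be missing.
fn first_char_upper(name: Option<&str>) -> Option<char> {
    // `?` returns `None` early instead of crashing on a missing value.
    name?.chars().next().map(|c| c.to_ascii_uppercase())
}

fn main() {
    // The compiler forces us to handle both cases before using the value.
    match first_char_upper(Some("rust")) {
        Some(c) => println!("first char: {c}"),
        None => println!("nothing to show"),
    }
}
```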

Rust makes pull requests easier to review and increases confidence when shipping to production. The compiler error messages might seem overwhelming at the beginning, just like TypeScript errors.

However, if you have ever been under the stress of recovering production servers, you know that learning to resolve compile-time errors is better than resolving runtime exceptions.

To embrace the Rust safety net, immutability and ownership are two key concepts to understand.

Variables are immutable by default

"Immutable by default" means that once data is created, it can't be mutated; most will agree that immutable data makes your code better.

For example, a seasoned TypeScript developer probably knows the benefits of using const: it makes the intent explicit, and the compiler stops you when you try to reassign the value.

const x = 5;
x = 1; // Cannot assign to 'x' because it is a constant.

In Rust, variables are immutable by default and only mutable if you explicitly declare them with mut.

fn main() {
    let x = 5; // declare with `let mut x = 5;` to make mutation possible
    x = 6; // error[E0384]: cannot assign twice to immutable variable `x`
}

The book's Variables and Mutability chapter has a comprehensive explanation of mutability in Rust.
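A minimal sketch of opting in to mutation with mut; bump and collect_langs are hypothetical helpers:

```rust
// `mut` opts in to mutation; without it, `x += 1` would not compile.
fn bump(start: i32) -> i32 {
    let mut x = start;
    x += 1;
    x
}

// The same applies to collections: `push` needs a mutable binding.
fn collect_langs() -> Vec<&'static str> {
    let mut langs = Vec::new();
    langs.push("TypeScript");
    langs.push("Rust");
    langs
}

fn main() {
    println!("{} {:?}", bump(5), collect_langs());
}
```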

You should not move! Ownership

Coming from a language with a garbage collector, the following code looks natural: we try to create s2 by referencing s1:

fn main() {
  let s1 = String::from("hello");
  let s2 = s1;
  println!("{}", s1);
}

However, the code does not compile; the compiler says we have moved s1.

11 |   let s1 = String::from("hello");
   |       -- move occurs because `s1` has type `String`, which does not implement the `Copy` trait
12 |   let s2 = s1;
   |            -- value moved here
13 |   println!("{}", s1);
   |                  ^^ value borrowed here after move

This might be the first, and a continuously frustrating, compiler error message you meet when starting Rust.

You should not pass!

Rust does not ship with a garbage collector, which means nothing at runtime figures out when a value is no longer needed and can be dropped from memory.

To achieve this goal, Rust introduces the ownership checker (officially, the borrow checker), which makes it explicit when the rest of the code no longer needs a value. The ownership checker helps manage memory at compile time, so we don't need to ship the code with a garbage collector that collects and drops unused values from memory at runtime.

value moved here in the error above tells us that the code violates the ownership rules, which are:

  1. Each value in Rust has an owner.
  2. There can only be one owner at a time.
  3. When the owner goes out of scope, the value will be dropped.

The compiler is telling us: Hey! s1 was the owner of String::from("hello"), but you moved the ownership from s1 to s2, so s1 is no longer valid, and you should not use it again in println!!

fn main() {
  let s1 = String::from("hello");
  let s2 = s1; // ownership moved from s1 to s2
  println!("{}", s1); // s1 was moved out, so why are you still using it?
}

If you come from the TypeScript world (or any language with a garbage collector), ownership might look foreign; however, learning to work with the ownership checker makes you aware of memory allocation.
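The error message above mentions the Copy trait. As a small sketch of the difference: integers implement Copy, so assignment copies them and both variables stay usable, while a String has to be duplicated explicitly with .clone(), which is one common (allocation-costing) way to keep using s1. copy_and_clone is a hypothetical helper.

```rust
// Hypothetical helper showing `Copy` versus an explicit `clone()`.
fn copy_and_clone() -> (i32, i32, String, String) {
    // i32 implements `Copy`: assignment copies the value, so `a` stays valid
    let a = 5;
    let b = a;

    // String does not implement `Copy`; `clone()` allocates a second, independent String
    let s1 = String::from("hello");
    let s2 = s1.clone();

    // all four variables are still usable here
    (a, b, s1, s2)
}

fn main() {
    let (a, b, s1, s2) = copy_and_clone();
    println!("{a} {b} {s1} {s2}");
}
```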

Ownership and Scope

Let's review the ownership rules again and dig deeper into the third rule.

  1. Each value in Rust has an owner.
  2. There can only be one owner at a time.
  3. When the owner goes out of scope (the curly brackets {}), the value will be dropped.
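To watch the third rule happen, here is a std-only sketch using a hypothetical Noisy type whose Drop implementation records its name in a shared log when it goes out of scope:

```rust
use std::cell::RefCell;
use std::rc::Rc;

// Hypothetical type that writes its name into a shared log when it is dropped.
struct Noisy {
    name: &'static str,
    log: Rc<RefCell<Vec<&'static str>>>,
}

impl Drop for Noisy {
    fn drop(&mut self) {
        self.log.borrow_mut().push(self.name);
    }
}

fn scope_demo() -> Vec<&'static str> {
    let log = Rc::new(RefCell::new(Vec::new()));
    {
        let _outer = Noisy { name: "outer", log: Rc::clone(&log) };
        {
            let _inner = Noisy { name: "inner", log: Rc::clone(&log) };
        } // `_inner` goes out of scope here and is dropped
        log.borrow_mut().push("end of outer scope");
    } // `_outer` goes out of scope here and is dropped
    // bind first so the temporary borrow ends before `log` itself is dropped
    let events = log.borrow().clone();
    events
}

fn main() {
    // prints ["inner", "end of outer scope", "outer"]
    println!("{:?}", scope_demo());
}
```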

In the following example, the compiler stops us at the print! call after do_something(), because we violate the ownership rules by moving owner into do_something and then trying to use owner again.

This does not compile:

fn main() {
    let owner = String::from("value");
    // we took the ownership of "value" from `owner` and
    // "value" is dropped at the end of the `do_something` function
    // thus the variable `owner` does not own it anymore
    do_something(owner);
    // use of moved value: `owner` value used here after move
    print!("{}", owner);
}

fn do_something(_: String) {
    // 
}


print!("{}", owner) violates the ownership rules because we have already moved owner into do_something(owner)'s scope. When do_something(owner) finishes executing, its parameter goes out of scope, the String is dropped, and we can't use owner anymore.

Borrow

To get around the ownership rules, borrowing comes to the rescue.

Borrowing uses the reference syntax (&) to let the Rust compiler know that we are just borrowing instead of taking ownership. A reference is a promise that we are only temporarily borrowing the value: we do not intend to take ownership, and we give the value back when we don't need it anymore.

This compiles:

fn main() {
    let owner = String::from("value");
    // `do_something` borrows `"value"` from `owner`
    do_something(&owner);
    // No more error!
    print!("{}", owner);
}

fn do_something(_: &String) {
    // "value" is NOT dropped at the end of the function
    // because we are just borrowing (`&String`) not taking the ownership
}


Just like ownership, borrowing has a set of rules; these rules are like the contract you make when you borrow something from someone else.

A real-world analogy: you want to borrow the book "Rust for Rustaceans" from a friend. To keep the friendship, you make a contract (a verbal promise: "I will return the book to you in one month"), and the contract needs to follow the borrowing rules:

  1. At any given time, you can have either any number of immutable references or exactly one mutable reference, but not both.
  2. References must always point to a valid value (no referencing a dropped value).
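A minimal sketch of rule 1 (the function name borrow_rules is hypothetical). Rule 2 is enforced at compile time in a similar way: for example, the compiler rejects a function that tries to return a reference to one of its own local variables.

```rust
fn borrow_rules() -> String {
    let mut book = String::from("Rust for Rustaceans");

    // Any number of immutable references may exist at the same time...
    let r1 = &book;
    let r2 = &book;
    println!("{r1} / {r2}");
    // ...and their borrows end after their last use above.

    // Exactly one mutable reference, with no immutable ones alive alongside it.
    let w = &mut book;
    w.push_str(" (returned in one month)");
    // let r3 = &book; // error[E0502] if uncommented: `w` is still used on the next line
    println!("{w}");

    // The mutable borrow has ended, so we can move `book` out again.
    book
}

fn main() {
    println!("{}", borrow_rules());
}
```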

It's OK if ownership and borrowing still seem blurry. The book's Understanding Ownership chapter is the best read, and you will get familiar with the ownership rules soon enough after passing data around and having the compiler yell at you from time to time.

If you are a busy developer, Let's Get Rusty's The Rust Survival Guide is a great crash course on the ownership rules.

Asynchronous

Before we start this section, let's pin down some terminology.

Terminology

Async is a feature of a programming language intended to give the program opportunities to execute a unit of computation while waiting for another unit of computation to complete.

Parallelism

The program executes units of computation simultaneously, for example running two computations on two different CPU cores.

Concurrency

The program processes units of computation by interleaving them: it runs one unit until that unit has to wait (for example, on I/O), then quickly switches to another one. Because the switching is fast, it looks as if the units run at the same time, but they are not executed simultaneously. The single-threaded Node.js runtime works this way.

Task (Green Thread)

A task is a unit of computation running in a parallel or concurrent system. In this article, the term task refers to an asynchronous green thread: not an OS thread, but a unit of execution managed by the async runtime.

Runtime (the Task Runner)

Node.js is a single-threaded, asynchronous runtime: a program can process tasks asynchronously, but it does not process them in parallel, because Node.js is single-threaded.

To process tasks asynchronously in Rust, you need to set up a task runner. The main function (think of it like index.ts), which is the entry point of a Rust program, is always synchronous, so you need to set up a runtime before you can run asynchronous tasks.

The following code uses futures::executor::block_on (from the futures crate) as the async task runner.

fn main() {
    // the async task runner.
    futures::executor::block_on(do_something());
}

// An async task
async fn do_something() {
    //
}
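To demystify what block_on does, here is a minimal std-only sketch of a block_on-style executor: it polls the future, parks the thread when the future returns Pending, and relies on the waker to unpark it. This is an illustration of the idea, not the futures crate's actual implementation.

```rust
use std::future::Future;
use std::pin::pin;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};
use std::thread::{self, Thread};

// Wakes the blocked thread by unparking it.
struct ThreadWaker(Thread);

impl Wake for ThreadWaker {
    fn wake(self: Arc<Self>) {
        self.0.unpark();
    }
}

// Drives a future to completion on the current thread.
fn block_on<F: Future>(fut: F) -> F::Output {
    // Pin the future on the stack so it can be polled.
    let mut fut = pin!(fut);
    let waker: Waker = Arc::new(ThreadWaker(thread::current())).into();
    let mut cx = Context::from_waker(&waker);
    loop {
        match fut.as_mut().poll(&mut cx) {
            Poll::Ready(output) => return output,
            // Sleep until the waker calls `unpark` (spurious wakeups just poll again).
            Poll::Pending => thread::park(),
        }
    }
}

// An async task
async fn do_something() -> u32 {
    42
}

fn main() {
    let answer = block_on(do_something());
    println!("answer = {answer}");
}
```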

In Rust, you are free to choose any async runtime, just as in TypeScript we have Node.js, Bun and Deno. In Rust we have tokio, async-std, smol and futures. These runtimes can be single-threaded, running tasks concurrently like Node.js, or multi-threaded, which gives true parallelism.

You may find these videos useful for understanding async/await in Rust.

Ownership and Async

In You should not move! Ownership we discussed the ownership rules, and in the Borrow section, we discussed how to get around them by borrowing.

In async Rust, whether you are using a single-threaded runtime with concurrent green threads or distributing computation across multiple OS threads (parallelism), the ownership rules always apply. In async Rust, the ownership rules prevent data races in concurrent and parallel programs (also known as fearless concurrency).

Remember the ownership rules?

  1. Each value in Rust has an owner.
  2. There can only be one owner at a time.
  3. When the owner goes out of scope the value will be dropped.

Pay special attention to "at a time": it is how ownership helps us avoid data races when running computations at the same time.

Let's look at the synchronous version again. In the previous example, this failed to compile...

fn main() {
    let owner = String::from("value");
    do_something(owner);
    // use of moved value: `owner` value used here after move
    print!("{}", owner);
}

fn do_something(_: String) {
    // 
}

...because the code does not follow the ownership rules: do_something(owner) took ownership of String::from("value"), and the Rust compiler only allows one owner at a time. To make sure the String is freed exactly once, owner is moved into do_something(owner), and the code fails to compile with the error use of moved value: `owner` value used here after move.

3 |     do_something(owner);
  |                  ----- value moved here
4 |     // use of moved value: `owner` value used here after move
5 |     print!("{}", owner);
  |                  ^^^^^ value borrowed here after move

We can get around this by borrowing (&):

fn main() {
    let owner = String::from("value");
    // use & to reference owner
    do_something(&owner);
    // We can still use owner after
    print!("{}", owner);
}

fn do_something(_: &String) {
    // 
}



The same ownership rules apply to asynchronous Rust. Let's look at a parallel version, which spawns an OS thread running code simultaneously:

use std::thread;

fn main() {
    let owner = String::from("value");
    thread::spawn(|| {
        do_something(&owner);
    });
}
fn do_something(_: &String) {
    // 
}


We know that we need borrowing (&) to avoid taking ownership when calling do_something(&owner). However, the compiler still rejects the code, saying:

closure may outlive the current function, but it borrows owner, which is owned by the current function

This compiler error tells us that the reference to owner inside the thread closure might point to a value that has already been dropped outside the closure, violating this borrowing rule we discussed in Borrow:

  2. References must always point to a valid value (no referencing a dropped value).

To give this outlive error more context, try to run this code in the playground.

use std::thread;

fn main() {
    thread::spawn(|| {
        print!("from thread");
    });
    print!("from main");
}

You might be surprised that, most likely, only from main shows up in the console. thread::spawn from std creates a detached thread: nothing makes the parent thread (in our case main()) wait for the child thread, so a child created via thread::spawn can outlive the function that spawned it. When main() returns, the whole process exits, usually before the child has had a chance to print.

That's why you see from main in the console: the execution of || print!("from thread") is not tied to the execution of main.
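If we want main to wait, thread::spawn returns a JoinHandle, and calling join() blocks until the spawned thread finishes. A minimal sketch (spawn_and_wait is hypothetical). Note that joining does not make borrowing from main legal: thread::spawn still requires the closure to own everything it uses, because its signature demands a 'static closure regardless of whether join() is called.

```rust
use std::thread;

// Hypothetical helper: spawns a thread and waits for it before returning.
fn spawn_and_wait() -> String {
    let handle = thread::spawn(|| {
        print!("from thread ");
        String::from("thread finished")
    });
    print!("from main ");
    // `join` blocks until the spawned thread completes and hands back its return value
    handle.join().unwrap()
}

fn main() {
    let result = spawn_and_wait();
    println!("\n{result}");
}
```

Both messages are printed now, in an order decided by the scheduler.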

If we step back and think about borrowing across threads at a higher level:

use std::thread;

fn main() {
    let owner = String::from("value");
    thread::spawn(|| {
        // we borrow owner, but the borrowed value (`owner`)
        // might be dropped in main(), that is the `&` might point
        // to a dropped value
        do_something(&owner);
    });
}
fn do_something(_: &String) {
    // 
}


We are running code in main and in a thread simultaneously, and the compiler stops us by telling us that the closure (in the thread) may outlive the current function, but it borrows owner, which is owned by main(). We shouldn't do this, because the thread might reference owner after it has become invalid in the parent thread (main()).

The same outlive problem shows up in concurrent code, even though in most concurrent runtimes code runs in tasks rather than in OS threads:

/*
[dependencies]
tokio = { version = "1.32.0", features = ["full"] }
*/

#[tokio::main]
async fn main() {
    let owner = String::from("value");

    tokio::spawn(do_something(&owner));
}

async fn do_something(_: &String) {
    //
}


This code fails with a similar error message: `owner` does not live long enough.

To get around this "threads don't borrow" error, that is, the ownership rule that disallows child threads/tasks from referencing a value owned by their parent, we have a few solutions:

  1. Move the value into the thread with the move keyword.
use std::thread;

fn main() {
    let owner = String::from("value");
    // `move` moving the `owner` into the spawned thread
    thread::spawn(move || {
        do_something(&owner);
    });
}
fn do_something(_: &String) {
    // 
}


  2. Use a scoped thread: every thread spawned inside thread::scope is joined before thread::scope returns, so it cannot outlive the values it borrows from main.
use std::thread;

fn main() {
    let owner = String::from("value");
    // every thread spawned inside the scope is joined before `thread::scope` returns,
    // therefore the spawned thread can use a reference pointing to `owner`
    thread::scope(|s| {
        s.spawn(|| {
            do_something(&owner);
        });
    });
}
fn do_something(_: &String) {
    // 
}


  3. "Do not communicate by sharing memory; instead, share memory by communicating", as the Go language documentation puts it. We will dive into this in the actor section.
  4. Atomic Reference Counting (Arc<T>) and Mutual Exclusion (Mutex<T>).
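For option 3, here is a minimal std-only sketch using std::sync::mpsc channels: each worker thread owns its chunk of data and sends its result back instead of sharing memory. sum_with_workers is a hypothetical function, and the actor section may use a different channel type.

```rust
use std::sync::mpsc;
use std::thread;

// Hypothetical function: each worker owns its data and *sends* a result back.
fn sum_with_workers(chunks: Vec<Vec<u32>>) -> u32 {
    let (tx, rx) = mpsc::channel();
    for chunk in chunks {
        // every thread gets its own sender handle (the channel is multi-producer)
        let tx = tx.clone();
        thread::spawn(move || {
            let partial: u32 = chunk.iter().sum();
            tx.send(partial).unwrap();
        });
    }
    // drop the original sender so the receiver's iterator ends once all workers are done
    drop(tx);
    rx.iter().sum()
}

fn main() {
    println!("{}", sum_with_workers(vec![vec![1, 2, 3], vec![4, 5]]));
}
```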

We will dive into Arc<T> in the next section.

Share States in Async Program: Arc and Mutex

Sharing state in an async program can be a challenge. Ownership rules only allow a value to have one owner at a time, and we can't use borrowing because the compiler cannot be sure that the borrower in the thread/task won't end up pointing to a dropped value.

To solve this problem, we can use Arc (Atomic Reference Counting).

Arc is safe to use for sharing state across multiple threads/tasks. Wrap the data in an Arc, and each clone of the Arc is another owning pointer to the same data (the data itself is not copied):

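use std::sync::Arc;
use std::thread;

fn main() {
    let arc = Arc::new(String::from("value"));
    // cloning an Arc bumps a reference count; both handles point to the same String
    let arc_for_thread = Arc::clone(&arc);
    let handle = thread::spawn(move || {
        do_something(arc_for_thread);
    });
    handle.join().unwrap();
    // `arc` is still usable in main
    println!("{}", arc);
}
fn do_something(_: Arc<String>) {
    // 
}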


Arc allows safe reads of the inner data across threads; it's similar to borrowing, but each thread holds shared ownership, so the data stays alive until the last Arc is dropped.
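To see that the threads share one allocation rather than copies, this std-only sketch checks Arc::ptr_eq and Arc::strong_count before and after a thread uses its clone (arc_demo is a hypothetical helper):

```rust
use std::sync::Arc;
use std::thread;

// Returns (same allocation?, count while shared, length seen by the thread, count after join).
fn arc_demo() -> (bool, usize, usize, usize) {
    let data = Arc::new(String::from("value"));
    // `Arc::clone` copies the pointer and bumps the reference count; the String is not copied.
    let for_thread = Arc::clone(&data);
    let same_allocation = Arc::ptr_eq(&data, &for_thread);
    let count_while_shared = Arc::strong_count(&data);

    // The thread only reads through its Arc.
    let handle = thread::spawn(move || for_thread.len());
    let len_in_thread = handle.join().unwrap();

    // The thread's clone was dropped when its closure finished, so the count is back to 1.
    let count_after_join = Arc::strong_count(&data);
    (same_allocation, count_while_shared, len_in_thread, count_after_join)
}

fn main() {
    println!("{:?}", arc_demo());
}
```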

However, Arc only gives shared, read-only access. To let threads write to the inner data, the data needs a proper locking mechanism, that is, Mutex<T>.

Mutex<T> (read: mutual exclusion) blocks threads waiting for the lock to become available. When a thread calls lock(), it becomes the only thread that can access the data; Mutex<T> blocks other threads from accessing the data until the lock is released (when the returned guard is dropped), so it's safe to mutate the data while holding the lock.

To safely mutate the shared state:

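use std::sync::{Arc, Mutex};
use std::thread;

fn main() {
    let inner_data = String::from("Hello");
    let mutex = Arc::new(Mutex::new(inner_data));
    let mutex_clone = Arc::clone(&mutex);
    let mutex_for_main = Arc::clone(&mutex);
    let first = thread::spawn(move || {
        // `lock()` blocks until the lock is free; the guard unlocks when it is dropped
        let mut inner_data = mutex.lock().unwrap();
        inner_data.push_str(" world (once)!");
    });
    let second = thread::spawn(move || {
        let mut inner_data = mutex_clone.lock().unwrap();
        inner_data.push_str(" world (twice)!");
    });
    // wait for both threads; otherwise main may exit before they run
    first.join().unwrap();
    second.join().unwrap();
    // the order of the two appends depends on which thread gets the lock first
    println!("{}", mutex_for_main.lock().unwrap());
}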


We will dive deeper into how to use Arc<Mutex<_>> to share mutable state in the section Share Memory Arc<Mutex<_>>.

To learn more about sharing state:

Hands-On

In the following sections, we will start building the copilot server.

Server-Sent Events (SSE) Server

In this PR, we add the endpoint for the copilot client.

From Code Llama as Drop-In Replacement for Copilot Code Completion we know that a copilot server is essentially an HTTP server that accepts a request with a prompt and returns JSON chunks as Server-Sent Events (SSE). Let's specify the SSE endpoint and create a Server-Sent Events (SSE) server using axum.

  • The URL path of the endpoint is /v1/engines/:engine/completions.
  • The endpoint should accept a POST request.
  • The endpoint takes a path parameter (:engine) and a request body.
  • The endpoint returns an SSE stream of text chunks (Content-Type: text/event-stream).

Since this endpoint is almost identical to OpenAI's completions endpoint, we can use curl to see the input (request body) and the output (SSE text chunks):

curl https://api.openai.com/v1/completions \
  -H "Content-Type: application/json" \
  -H "Authorization: Bearer $OPENAI_API_KEY" \
  -d '{
    "model": "gpt-3.5-turbo-instruct",
    "prompt": "Say this is a test",
    "max_tokens": 7,
    "temperature": 0,
    "stream": true
  }'
# chunk 0
data: {"choices":[{"text":"This ","index":0,"logprobs":null,"finish_reason":null}],
"model":"gpt-3.5-turbo-instruct", "id":"...","object":"text_completion","created":1}
# chunk 1
data: {"choices":[{"text":"is ","index":0,"logprobs":null,"finish_reason":null}],
"model":"gpt-3.5-turbo-instruct", "id":"...","object":"text_completion","created":1}
# chunk 2
data: {"choices":[{"text":"a ","index":0,"logprobs":null,"finish_reason":null}],
"model":"gpt-3.5-turbo-instruct", "id":"...","object":"text_completion","created":1}
# chunk with `"finish_reason":"stop"`
data: {"choices":[{"text":"test.","index":0,"logprobs":null,"finish_reason":"stop"}],
"model":"gpt-3.5-turbo-instruct", "id":"...","object":"text_completion","created":1}
# end of SSE event stream
data: [DONE]
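For reference, this is roughly how each chunk is framed on the wire: an SSE event is one or more data: lines followed by a blank line (the JSON above is wrapped only for readability). The sse_frame helper is hypothetical; in our server, axum's Sse type takes care of the framing.

```rust
// Hypothetical helper: formats one Server-Sent Event carrying `data`.
// Every line of the payload gets its own `data: ` prefix,
// and a blank line terminates the event.
fn sse_frame(data: &str) -> String {
    let mut frame = String::new();
    for line in data.split('\n') {
        frame.push_str("data: ");
        frame.push_str(line);
        frame.push('\n');
    }
    frame.push('\n');
    frame
}

fn main() {
    let chunk = r#"{"choices":[{"text":"This ","index":0}]}"#;
    print!("{}", sse_frame(chunk));
    print!("{}", sse_frame("[DONE]"));
}
```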

BDD (Behaviour-Driven Development) the Endpoint

We will use reqwest_eventsource and its friends in the test to act as the client: it sends a request to our endpoint /v1/engines/:engine/completions, and the test asserts that the response is what we expect. Since reqwest_eventsource and friends are not used in our final binary, let's add them to dev-dependencies in Cargo.toml.

[dev-dependencies]
reqwest = { version = "0.11.22", features = ["json", "stream", "multipart"] }
reqwest-eventsource = "0.5.0"
eventsource-stream = "0.2.3"

Add a dummy handler and axum's router to route requests to POST /v1/engines/:engine/completions:

use axum::{routing::post, Router};

async fn completion() -> &'static str {
    "Hello, World!"
}

fn app() -> Router {
    Router::new()
        .route("/v1/engines/:engine/completions", post(completion))
}

Translate the spec into the test:

#[cfg(test)]
mod tests {
    // imports are only for the tests
    // ...

    /// `super::*` means "everything in the parent module"
    /// It will bring all of the test module’s parent’s items into scope.
    use super::*;
    /// A helper function that spawns our application in the background
    /// and returns its address (e.g. http://127.0.0.1:[random_port])
    async fn spawn_app(host: impl Into<String>) -> String {
        let _host = host.into();
        // Bind to localhost at the port 0, which will let the OS assign an available port to us
        let listener = TcpListener::bind(format!("{}:0", _host)).await.unwrap();
        // We retrieve the port assigned to us by the OS
        let port = listener.local_addr().unwrap().port();

        let _ = tokio::spawn(async move {
            let app = app();
            axum::serve(listener, app).await.unwrap();
        });

        // We return the application address to the caller!
        format!("http://{}:{}", _host, port)
    }

    /// The #[tokio::test] annotation on the test_sse_engine_completion function is a macro.
    /// Similar to #[tokio::main], it transforms the async fn test_sse_engine_completion()
    /// into a synchronous fn test_sse_engine_completion() that initializes a runtime instance
    /// and executes the async body.
    #[tokio::test]
    async fn test_sse_engine_completion() {
        let listening_url = spawn_app("127.0.0.1").await;
        let mut completions: Vec<Completion> = vec![];
        let model_name = "code-llama-7b";
        let body = serde_json::json!({ ... });

        let time_before_request = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs();
        let mut stream = reqwest::Client::new()
            .post(&format!(
                "{}/v1/engines/{engine}/completions",
                listening_url,
                engine = model_name
            ))
            .header("Content-Type", "application/json")
            .json(&body)
            .send()
            .await
            .unwrap()
            .bytes_stream()
            .eventsource();

        // iterate over the stream of events
        // and collect them into a vector of Completion objects
        while let Some(event) = stream.next().await {
            match event {
                Ok(event) => {
                    // break the loop at the end of SSE stream
                    if event.data == "[DONE]" {
                        break;
                    }

                    // parse the event data into a Completion object
                    let completion = serde_json::from_str::<Completion>(&event.data).unwrap();
                    completions.push(completion);
                }
                Err(_) => {
                    panic!("Error in event stream");
                }
            }
        }
        // The endpoint should return at least one completion object
        assert!(completions.len() > 0);

        // Check that each completion object has the correct fields
        // note that we didn't check all the values of the fields because
        // `serde_json::from_str::<Completion>` should panic if the field 
        // is missing or in unexpected format
        for completion in completions {
            // id should be a non-empty string
            assert!(completion.id.len() > 0);
            assert!(completion.object == "text_completion");
            assert!(completion.created >= time_before_request);
            assert!(completion.model == model_name);

            // each completion object should have at least one choice
            assert!(completion.choices.len() > 0);

            // check that each choice has a non-empty text
            for choice in completion.choices {
                assert!(choice.text.len() > 0);
                // finish_reason can be None or Some(String)
                match choice.finish_reason {
                    Some(finish_reason) => {
                        assert!(finish_reason.len() > 0);
                    }
                    None => {}
                }
            }

            assert!(completion.system_fingerprint == "");
        }
    }
}

Run the tests with cargo test. The tests fail, because we haven't implemented completion() yet.

Now add the endpoint. To pass the tests, the endpoint needs to respond with chunks of SSE events. Let's fake the values in the struct for now; we will connect the endpoint to the LLM later! It's important to stabilise the HTTP interface first.

use async_stream::stream;
use axum::response::sse::{Event as SseEvent, KeepAlive, Sse};
use axum::Json;
use futures::stream::Stream;
use oxpilot::types::{Choice, Completion, CompletionRequest, Usage};
use serde_json::{json, to_string};
use std::convert::Infallible;
use std::time::{SystemTime, UNIX_EPOCH};

// Reference: https://github.com/tokio-rs/axum/blob/main/examples/sse/src/main.rs
pub async fn completion(
    // `Json<T>` will automatically deserialize the request body to a type `T` as JSON.
    Json(body): Json<CompletionRequest>,
) -> Sse<impl Stream<Item = Result<SseEvent, Infallible>>> {
    // `stream!` is a macro from [`async_stream`](https://docs.rs/async-stream/0.3.5/async_stream/index.html) 
    // that makes it easy to create a `futures::stream::Stream` from a generator.
    Sse::new(stream! {
        yield Ok(
          // Create a new `SseEvent` with the default settings.
          // `SseEvent::default().data("Hello, World!")` will return `data: Hello, World!` as the event text chunk.
          SseEvent::default().data(
            // Serialize the `Completion` struct to JSON and return it as the event text chunk.
            to_string(
              // json! is a macro from serde_json that makes it easy to create JSON values from a struct.
              &json!(
                Completion {
                  id: "cmpl-".to_string(),
                  object: "text_completion".to_string(),
                  created: SystemTime::now()
                      .duration_since(UNIX_EPOCH)
                      .unwrap()
                      .as_secs(),
                  model: body.model.unwrap_or("unknown".to_string()),
                  choices: vec![Choice {
                      text: " world!".to_string(),
                      index: 0,
                      logprobs: None,
                      finish_reason: Some("stop".to_string()),
                  }],
                  usage: Usage {
                      prompt_tokens: 0,
                      completion_tokens: 0,
                      total_tokens: 0
                  },
                  system_fingerprint: "".to_string(),
                }
              )).unwrap()
            )
        );
    })
    .keep_alive(KeepAlive::default())
}

That's it, the tests should pass now.

running 1 test
test tests::test_sse_engine_completion ... ok

test result: ok. 1 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.04s

Alternatively, we can test the copilot end-to-end:

cargo run

This binds the server to port 6666, because main contains the following:

#[tokio::main]
async fn main() {
    // ..
    let listener = tokio::net::TcpListener::bind("0.0.0.0:6666").await.unwrap();
    let app = app();
    axum::serve(listener, app).await.unwrap();
}

Then edit settings.json in VS Code:

"github.copilot.advanced": {
   "debug.overrideProxyUrl": "http://localhost:6666"
}

Open any file, and we should see the ghost text " world!", served by our copilot server running on port 6666.

Builder Pattern

In this section, we will implement a new struct (similar to an object), LLMBuilder, in llm.rs, and use it in our binary's entry point, main.rs.

To use components from llm.rs in main.rs, we lay out our files like this:

└── src
    ├── lib.rs
    ├── llm.rs
    └── main.rs

In llm.rs, we make LLMBuilder public:

// llm.rs
pub struct LLMBuilder {}

Declare a new module (mod) in lib.rs:

// lib.rs
pub mod llm; // rust will resolve to `./llm.rs`

and use the module in main.rs (oxpilot is the crate name):

use oxpilot::llm::LLMBuilder;
fn main() {
  // 
}
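The same layout can be mimicked in a single file with an inline mod block, which makes Rust's visibility rules easy to see: items in a module are private by default, so main can only reach LLMBuilder because it is marked pub. This is a sketch; the tokenizer_repo_id field is a hypothetical stand-in, and in the real project lib.rs declares pub mod llm; while main.rs imports through the oxpilot crate name.

```rust
// An inline module standing in for `src/llm.rs` plus `pub mod llm;` in `lib.rs`.
mod llm {
    // `pub` makes the struct visible outside the `llm` module.
    #[derive(Debug, Default)]
    pub struct LLMBuilder {
        // Fields stay private by default, even on a `pub` struct.
        tokenizer_repo_id: String,
    }

    impl LLMBuilder {
        // Outside code can't use a struct literal (the field is private),
        // so it goes through `new()` (or `Default::default()`).
        pub fn new() -> Self {
            Self::default()
        }
    }
}

// In the real crate this would be `use oxpilot::llm::LLMBuilder;`.
use llm::LLMBuilder;

fn main() {
    let builder = LLMBuilder::new();
    // `builder.tokenizer_repo_id` would not compile here: the field is private.
    println!("{:?}", builder);
}
```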

LLMBuilder is implemented with the Builder design pattern, a creational pattern that lets you construct complex objects step by step.

The end result looks like this:

let llm_builder = LLMBuilder::new()
        .tokenizer_repo_id("hf-internal-testing/llama-tokenizer")
        .model_repo_id("TheBloke/CodeLlama-7B-GGUF")
        .model_file_name("codellama-7b.Q2_K.gguf");
let llm = llm_builder.build().await;
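Before any type-level tricks, a builder like this can be sketched with Option fields that are checked inside build(). This is a simplified, synchronous, std-only sketch: LlmConfig is a hypothetical stand-in for the real LLM (which downloads a tokenizer and a model and is built asynchronously), and the error type is a plain String.

```rust
// Hypothetical stand-in for the real `LLM`: it only records where to load things from.
#[derive(Debug)]
pub struct LlmConfig {
    pub tokenizer_repo_id: String,
    pub model_repo_id: String,
    pub model_file_name: String,
}

// Every field starts as `None` and is filled in by a chained setter.
#[derive(Default)]
pub struct LLMBuilder {
    tokenizer_repo_id: Option<String>,
    model_repo_id: Option<String>,
    model_file_name: Option<String>,
}

impl LLMBuilder {
    pub fn new() -> Self {
        Self::default()
    }

    // Each setter takes `self` by value and returns it, which enables chaining.
    pub fn tokenizer_repo_id(mut self, id: impl Into<String>) -> Self {
        self.tokenizer_repo_id = Some(id.into());
        self
    }

    pub fn model_repo_id(mut self, id: impl Into<String>) -> Self {
        self.model_repo_id = Some(id.into());
        self
    }

    pub fn model_file_name(mut self, name: impl Into<String>) -> Self {
        self.model_file_name = Some(name.into());
        self
    }

    // Missing mandatory fields are only detected here, at runtime.
    pub fn build(self) -> Result<LlmConfig, String> {
        Ok(LlmConfig {
            tokenizer_repo_id: self.tokenizer_repo_id.ok_or("missing tokenizer_repo_id")?,
            model_repo_id: self.model_repo_id.ok_or("missing model_repo_id")?,
            model_file_name: self.model_file_name.ok_or("missing model_file_name")?,
        })
    }
}

fn main() {
    let llm = LLMBuilder::new()
        .tokenizer_repo_id("hf-internal-testing/llama-tokenizer")
        .model_repo_id("TheBloke/CodeLlama-7B-GGUF")
        .model_file_name("codellama-7b.Q2_K.gguf")
        .build();
    println!("{llm:?}");
}
```

The weakness of this version is that forgetting a setter still compiles; the mistake only shows up as an Err when build() runs.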

Constructor ::default() vs. ::new()

Rust does not have constructors that assign values to a struct's fields when a new instance is created. Instead, it's common to use an associated function, ::new(), for the same purpose. Another option is to implement the Default trait as a "constructor".

We implement the Default trait for LLMBuilder, and also implement new() for users who prefer the ::new() pattern.

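#[derive(Default)] // or a hand-written `impl Default for LLMBuilder`
pub struct LLMBuilder {}

impl LLMBuilder {
    pub fn new() -> Self {
        Self::default()
    }
}

fn main() {
    let _builder = LLMBuilder::new(); // same as `LLMBuilder::default()`
}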
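When some fields need non-empty defaults, Default can be written by hand instead of derived, and new() simply delegates to it. In this sketch the fields and the default value 128 are hypothetical, chosen only for illustration; the struct update syntax ..Self::default() fills in every field that isn't set explicitly.

```rust
#[derive(Debug, PartialEq)]
pub struct LLMBuilder {
    // Hypothetical fields, for illustration only.
    tokenizer_repo_id: String,
    max_tokens: usize,
}

// A hand-written `Default` lets us choose non-zero defaults.
impl Default for LLMBuilder {
    fn default() -> Self {
        Self {
            tokenizer_repo_id: String::new(),
            max_tokens: 128, // hypothetical default
        }
    }
}

impl LLMBuilder {
    // `new()` is just a conventional alias for `default()`.
    pub fn new() -> Self {
        Self::default()
    }

    // Struct update syntax: set one field, take the rest from `default()`.
    pub fn with_max_tokens(max_tokens: usize) -> Self {
        Self { max_tokens, ..Self::default() }
    }
}

fn main() {
    assert_eq!(LLMBuilder::new(), LLMBuilder::default());
    println!("{:?}", LLMBuilder::with_max_tokens(256));
}
```

Providing both is also what Clippy's new_without_default lint nudges toward: a public zero-argument new() on a type without a Default impl triggers it.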

impl Into<String> for function parameter

To make our struct's methods friendlier to use, we use the impl Into<String> trick, which lets callers pass either a String or a &str as the argument.

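impl LLMBuilder {
    // Assumes LLMBuilder has a `tokenizer_repo_id: String` field.
    pub fn tokenizer_repo_id(mut self, param: impl Into<String>) -> Self {
        // `.into()` turns a `&str` or a `String` into an owned `String`.
        self.tokenizer_repo_id = param.into();
        // Returning `self` allows the calls to be chained.
        self
    }
}
// both are accepted
LLMBuilder::new().tokenizer_repo_id("string_slice");
LLMBuilder::new().tokenizer_repo_id(String::from("String"));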
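impl Into<String> in argument position is shorthand for a generic parameter bounded by Into<String>. The sketch below (the greet functions are hypothetical) shows the two equivalent spellings and a few argument types that satisfy the bound; they work because the standard library implements From<&str>, From<&String> and From<char> for String.

```rust
// Argument-position `impl Trait`: concise, but the caller can't specify
// this parameter's type with turbofish.
fn greet(name: impl Into<String>) -> String {
    // `.into()` converts whatever was passed into an owned `String`.
    let name: String = name.into();
    format!("hello, {name}")
}

// The equivalent explicit generic form.
fn greet_generic<S: Into<String>>(name: S) -> String {
    format!("hello, {}", name.into())
}

fn main() {
    let owned = String::from("String");
    println!("{}", greet("string_slice")); // &str
    println!("{}", greet(&owned)); // &String
    println!("{}", greet(owned.clone())); // String, moved in
    println!("{}", greet('x')); // char
    println!("{}", greet_generic::<&str>("turbofish works here"));
}
```

Each distinct argument type produces its own monomorphized copy of the function, which is negligible for small builder setters like these.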

Type State

In the previous section, we implemented a builder for LLM. That's great: we can construct an LLM with a descriptive chain of methods.

let llm_builder = LLMBuilder::new()
    .tokenizer_repo_id("repo")
    .model_repo_id("repo")
    .model_file_name("file");
let llm = llm_builder.build().await;

However, let's step back and put ourselves in the users' shoes. When using LLMBuilder, a user might forget to supply a mandatory parameter; for example, one may forget to chain model_repo_id():

let llm_builder = LLMBuilder::new()
    .model_file_name("file");

This is acceptable. Unlike languages designed to throw exceptions at runtime, Rust returns errors to the caller as a Result. So there won't be unexpected runtime exceptions as long as the user handles the Result properly, and the compiler warns if a Result is silently ignored:

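let llm_builder = LLMBuilder::new()
    .model_file_name("file");
let llm = match llm_builder.build().await {
    Ok(llm) => llm,
    Err(error) => {
        // handle the error properly here, e.g. report it and exit
        eprintln!("failed to build the LLM: {error}");
        std::process::exit(1);
    }
};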
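Instead of matching on every Result, a function that itself returns a Result can forward errors with the ? operator. A minimal std-only sketch, where build_llm is a hypothetical stand-in for llm_builder.build().await that fails when the model repo id is missing, and MissingParam is a hypothetical error type:

```rust
use std::fmt;

// Hypothetical error type for a missing mandatory builder parameter.
#[derive(Debug, PartialEq)]
struct MissingParam(&'static str);

impl fmt::Display for MissingParam {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "missing mandatory parameter: {}", self.0)
    }
}

impl std::error::Error for MissingParam {}

// Hypothetical stand-in for `llm_builder.build().await`.
fn build_llm(model_repo_id: Option<&str>) -> Result<String, MissingParam> {
    // `?` returns early with the error when the option is `None`.
    let repo = model_repo_id.ok_or(MissingParam("model_repo_id"))?;
    Ok(format!("LLM loaded from {repo}"))
}

// `?` also converts the error via `From`, here into `Box<dyn Error>`.
fn run(model_repo_id: Option<&str>) -> Result<String, Box<dyn std::error::Error>> {
    let llm = build_llm(model_repo_id)?;
    Ok(llm)
}

fn main() {
    // Errors are handled once, at the top level.
    match run(None) {
        Ok(llm) => println!("{llm}"),
        Err(error) => eprintln!("{error}"),
    }
}
```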

But what if we could improve the developer experience (DX) by surfacing the problem as early as possible, shortening the feedback loop to the moment the code is written, i.e., as a compile-time error?

Type State is a pattern that encodes an object's state in its type, so the compiler checks the state before the code ever runs.

Our goal is to make the compiler reject code in which mandatory parameters for creating an LLM are missing. For example, this should fail to compile:

let llm_builder = LLMBuilder::new();
let llm = llm_builder.build().await;

The compiler will tell the user: hey, build() can't be used yet, you should not pass!

And code intelligence in the editor will suggest: hey, there is a tokenizer_repo_id() method available, do you want to try it first?

We can help our users find the next step by defining the type states:

// Init state when `::new()` is called.
pub struct InitState;

// Intermediate state, with token repo id, ready to accept model repo id
pub struct WithTokenizerRepoId;

Then, we move each method into the impl block for the state in which it is valid. At the beginning, the state is InitState, and the user can only call new() (which does not change the state) and tokenizer_repo_id(), which returns the builder in the WithTokenizerRepoId state.

impl LLMBuilder<InitState> {
    pub fn new() -> Self {
        LLMBuilder {
           ...
            // does not change state
            state: InitState,
        }
    }
    pub fn tokenizer_repo_id(
        self,
        tokenizer_repo_id: impl Into<String>,
    ) -> LLMBuilder<WithTokenizerRepoId> {
        LLMBuilder {
            ...
            // change state to `WithTokenizerRepoId`
            state: WithTokenizerRepoId,
        }
    }
}

If we inspect the builder instance after calling tokenizer_repo_id(), we can see that it is in the WithTokenizerRepoId state.

[Screenshot: inspecting the builder instance in the editor shows the WithTokenizerRepoId state]

That's great! Let's implement the builder for the WithTokenizerRepoId state, so the user knows what to do next.

With tokenizer_repo_id in place, the next step is to set model_repo_id: calling model_repo_id() sets the field and returns LLMBuilder<WithModelRepoId>.

// Intermediate state, with model repo id, ready to accept model file name
pub struct WithModelRepoId;

impl LLMBuilder<WithTokenizerRepoId> {
    pub fn model_repo_id(self, model_repo_id: impl Into<String>) -> LLMBuilder<WithModelRepoId> {
        LLMBuilder {
            ...
           // change state to `WithModelRepoId`
            state: WithModelRepoId,
        }
    }
}

We are almost there. The final step is to assign model_file_name, after which the builder is ready to build.

/// With tokenizer repo id, model repo id and model file name: ready to build
pub struct ReadyState;

impl LLMBuilder<WithModelRepoId> {
    pub fn model_file_name(self, model_file_name: impl Into<String>) -> LLMBuilder<ReadyState> {
        LLMBuilder {
            ...
            // change state to `ReadyState`
            state: ReadyState,
        }
    }
}

Finally, implement LLMBuilder<ReadyState>, which adds the build() method:

impl LLMBuilder<ReadyState> {
    pub async fn build(self) -> Result<LLM> { ... }
}
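Putting the pieces together, here is a self-contained, synchronous sketch of the type state builder. It does not reproduce the elided parts of the snippets above; instead it assumes the builder is generic over a State parameter stored in a state field, uses hypothetical String fields, and has build() return a hypothetical LlmConfig rather than loading a real model (so it is neither async nor fallible here).

```rust
// Zero-sized marker types, one per builder state.
pub struct InitState;
pub struct WithTokenizerRepoId;
pub struct WithModelRepoId;
pub struct ReadyState;

// The builder is generic over its current state.
pub struct LLMBuilder<State> {
    tokenizer_repo_id: String,
    model_repo_id: String,
    model_file_name: String,
    state: State,
}

// Hypothetical result of `build()`.
#[derive(Debug)]
pub struct LlmConfig {
    pub tokenizer_repo_id: String,
    pub model_repo_id: String,
    pub model_file_name: String,
}

impl LLMBuilder<InitState> {
    pub fn new() -> Self {
        LLMBuilder {
            tokenizer_repo_id: String::new(),
            model_repo_id: String::new(),
            model_file_name: String::new(),
            state: InitState,
        }
    }

    // Changing the type parameter means rebuilding the struct field by field;
    // `..self` can't be used because the resulting type differs.
    pub fn tokenizer_repo_id(self, id: impl Into<String>) -> LLMBuilder<WithTokenizerRepoId> {
        LLMBuilder {
            tokenizer_repo_id: id.into(),
            model_repo_id: self.model_repo_id,
            model_file_name: self.model_file_name,
            state: WithTokenizerRepoId,
        }
    }
}

impl LLMBuilder<WithTokenizerRepoId> {
    pub fn model_repo_id(self, id: impl Into<String>) -> LLMBuilder<WithModelRepoId> {
        LLMBuilder {
            tokenizer_repo_id: self.tokenizer_repo_id,
            model_repo_id: id.into(),
            model_file_name: self.model_file_name,
            state: WithModelRepoId,
        }
    }
}

impl LLMBuilder<WithModelRepoId> {
    pub fn model_file_name(self, name: impl Into<String>) -> LLMBuilder<ReadyState> {
        LLMBuilder {
            tokenizer_repo_id: self.tokenizer_repo_id,
            model_repo_id: self.model_repo_id,
            model_file_name: name.into(),
            state: ReadyState,
        }
    }
}

// `build()` only exists once every mandatory parameter has been set.
impl LLMBuilder<ReadyState> {
    pub fn build(self) -> LlmConfig {
        LlmConfig {
            tokenizer_repo_id: self.tokenizer_repo_id,
            model_repo_id: self.model_repo_id,
            model_file_name: self.model_file_name,
        }
    }
}

fn main() {
    let llm = LLMBuilder::new()
        .tokenizer_repo_id("hf-internal-testing/llama-tokenizer")
        .model_repo_id("TheBloke/CodeLlama-7B-GGUF")
        .model_file_name("codellama-7b.Q2_K.gguf")
        .build();
    println!("{llm:?}");
    // LLMBuilder::new().build(); // does not compile: no `build` on `LLMBuilder<InitState>`
}
```

Because the marker types are zero-sized, the state field adds nothing to the builder's size at runtime; the checking happens entirely in the type system.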

That's it, we have improved our builder. The compiler now emits an error when any mandatory parameter is missing, so the mistake is caught at compile time instead of at runtime.

The final result

let _ = LLMBuilder::new()
    // mandatory parameters, in this order; skip one and build() is not available
    .tokenizer_repo_id("string_slice")
    .model_repo_id("repo")
    .model_file_name("model.file");
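Two common variations on this pattern, shown in a shortened sketch with hypothetical names: the state can be carried as PhantomData<State> instead of a stored value, and methods that are valid in every state (here a hypothetical optional temperature setter) can live in a generic impl<S> LLMBuilder<S> block, so they don't have to be repeated for each state.

```rust
use std::marker::PhantomData;

pub struct InitState;
pub struct ReadyState;

pub struct LLMBuilder<State> {
    model_file_name: String,
    temperature: f64, // hypothetical optional parameter
    // Only the type matters; `PhantomData` holds no value at runtime.
    _state: PhantomData<State>,
}

impl LLMBuilder<InitState> {
    pub fn new() -> Self {
        LLMBuilder { model_file_name: String::new(), temperature: 0.2, _state: PhantomData }
    }

    pub fn model_file_name(self, name: impl Into<String>) -> LLMBuilder<ReadyState> {
        LLMBuilder {
            model_file_name: name.into(),
            temperature: self.temperature,
            _state: PhantomData,
        }
    }
}

// Available in every state, so it can be chained before or after mandatory setters.
impl<S> LLMBuilder<S> {
    pub fn temperature(mut self, temperature: f64) -> Self {
        self.temperature = temperature;
        self
    }
}

impl LLMBuilder<ReadyState> {
    pub fn build(self) -> (String, f64) {
        (self.model_file_name, self.temperature)
    }
}

fn main() {
    let (file, temperature) = LLMBuilder::new()
        .temperature(0.7)
        .model_file_name("model.file")
        .build();
    println!("{file} @ {temperature}");
}
```

Since temperature() returns Self, it keeps whatever state the builder is already in, so optional parameters never interfere with the mandatory sequence.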

Inspect the builder: it is in the ReadyState!

[Screenshot: inspecting the final builder in the editor shows the ReadyState]

Share Memory Arc<Mutex<_>>

Share Memory by Communicating: Actor

Top comments (4)

sandkumaCode

very well written thanks

Darryl Ruggles

A very interesting project - thanks for sharing!

Adi

This project would be so great in actix, unfortunately you chosen half baked clunky and slow younger brother axum

chh

You are welcome ˊ_>ˋ