FLUX (by Black Forest Labs) has taken the world of AI image generation by storm in the last few months. Not only has it beaten Stable Diffusion (the prior open-source king) on many benchmarks, it has also surpassed proprietary models like DALL-E and Midjourney on some metrics.
But how would you go about using FLUX in one of your apps? You might reach for a serverless host such as Replicate, but these services can get expensive quickly and may not give you the flexibility you need. That's where running your own custom FLUX server comes in handy.
In this article, we'll walk you through creating your own FLUX server using Python. This server will allow you to generate images based on text prompts via a simple API. Whether you're running this server for personal use or deploying it as part of a production application, this guide will help you get started.
Prerequisites
Before diving into the code, let's ensure you have the necessary tools and libraries set up:
- Python: You'll need Python 3 installed on your machine, preferably version 3.10.
- `torch`: The deep learning framework we'll use to run FLUX.
- `diffusers`: Provides access to the FLUX model.
- `transformers`: Provides the text encoders that the FLUX pipeline in `diffusers` depends on.
- `sentencepiece`: Required to run the FLUX tokenizer.
- `protobuf`: Required to run FLUX.
- `accelerate`: Helps load the FLUX model more efficiently in some cases.
- `fastapi`: Framework to create a web server that can accept image generation requests.
- `uvicorn`: Required to run the FastAPI server.
- `psutil`: Lets us check how much RAM the machine has.
You can install all the libraries by running the following command: `pip install torch diffusers transformers sentencepiece protobuf accelerate fastapi uvicorn psutil`.
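Before moving on, you could confirm that everything is importable. This stdlib-only sketch uses a hypothetical `missing_modules` helper built on `importlib.util.find_spec`, which locates modules without importing them. Note that import names don't always match pip names: `protobuf` is imported as `google.protobuf`.

```python
import importlib.util

# Import names for the packages installed above (protobuf's import name differs from its pip name)
REQUIRED_MODULES = [
    "torch", "diffusers", "transformers", "sentencepiece", "google.protobuf",
    "accelerate", "fastapi", "uvicorn", "psutil",
]


def missing_modules(names: list[str]) -> list[str]:
    """Return the module names that cannot be found in the current environment."""
    missing = []
    for name in names:
        try:
            found = importlib.util.find_spec(name) is not None
        except ModuleNotFoundError:  # raised when a parent package (e.g. "google") is missing
            found = False
        if not found:
            missing.append(name)
    return missing


print(missing_modules(REQUIRED_MODULES) or "All dependencies found")
```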
If you're using a Mac with an M1 or M2 chip, you should set up PyTorch with Metal for optimal performance. Follow the official PyTorch with Metal guide before proceeding.
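You'll also need at least 12 GB of VRAM if you plan to run FLUX on a GPU, or at least 12 GB of RAM to run it on CPU or MPS (which will be slower). Treat these as bare minimums: in fp16, the full pipeline's weights take considerably more than 12 GB, so loading everything onto one device, as this guide does, may run out of memory. On smaller GPUs, diffusers' `pipeline.enable_model_cpu_offload()` (called instead of `.to(device)`) trades speed for lower VRAM use.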
Step 1: Setting Up the Environment
Let's start the script by picking the right device to run inference based on the hardware we're using.
```python
device = 'cuda'  # can also be 'cpu' or 'mps'

import os

# Some PyTorch operations are not implemented for MPS yet; let them fall back
# to the CPU. Set this before importing torch so it takes effect.
if device == 'mps':
    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch

if device == 'mps' and not torch.backends.mps.is_available():
    raise RuntimeError("Device set to MPS, but MPS is not available")
elif device == 'cuda' and not torch.cuda.is_available():
    raise RuntimeError("Device set to CUDA, but CUDA is not available")
```
You can specify `cpu`, `cuda` (for NVIDIA GPUs), or `mps` (for Apple's Metal Performance Shaders). The script then checks whether the selected device is available and raises an exception if it's not.
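If you'd rather not hardcode the device, you could pick one automatically. The sketch below is a hypothetical `pick_device` helper that takes the availability flags as arguments so it runs without PyTorch; in the real script you would pass in `torch.cuda.is_available()` and `torch.backends.mps.is_available()`. Since checking MPS availability requires importing torch, you would then set `PYTORCH_ENABLE_MPS_FALLBACK` unconditionally before the import (it should only matter on MPS).

```python
def pick_device(cuda_available: bool, mps_available: bool) -> str:
    """Return the preferred device name: CUDA first, then MPS, then CPU.

    Hypothetical helper; in the server script you would call it as
    pick_device(torch.cuda.is_available(), torch.backends.mps.is_available()).
    """
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


# Example: a Mac with Apple silicon but no NVIDIA GPU
print(pick_device(cuda_available=False, mps_available=True))  # prints "mps"
```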
Step 2: Loading the FLUX Model
Next, we load the FLUX model. We'll load it in fp16 precision, which saves memory without much loss in quality.
At this point, you may be asked to authenticate with Hugging Face, because the FLUX model is gated. To authenticate, create a Hugging Face account, go to the model page and accept the terms, then create an access token in your account settings and set it on your machine as the `HF_TOKEN` environment variable.
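A missing token otherwise only shows up as an error from the download step, so it can help to fail early with a clearer message. This is a minimal sketch with a hypothetical helper, `require_hf_token`. It only checks that `HF_TOKEN` is set and non-empty, not that the token is valid. If you authenticate with `huggingface-cli login` instead, this check would be too strict.

```python
import os


def require_hf_token(env=os.environ) -> str:
    """Return the Hugging Face token from HF_TOKEN, or raise a clear error.

    Hypothetical helper: it only checks that the variable is present and
    non-empty; it does not contact Hugging Face to validate the token.
    """
    token = env.get("HF_TOKEN", "").strip()
    if not token:
        raise RuntimeError(
            "HF_TOKEN is not set. FLUX.1-dev is a gated model: accept its terms on "
            "Hugging Face, create an access token, and export it as HF_TOKEN."
        )
    return token
```

You would call `require_hf_token()` just before `FluxPipeline.from_pretrained(...)`.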
```python
from diffusers import FlowMatchEulerDiscreteScheduler, FluxPipeline
import psutil

model_name = "black-forest-labs/FLUX.1-dev"
print(f"Loading {model_name} on {device}")
pipeline = FluxPipeline.from_pretrained(
    model_name,
    # Diffusion models are generally trained on fp32, but fp16
    # gets us 99% there in terms of quality, with just half the (V)RAM
    torch_dtype=torch.float16,
    # Only load .safetensors weights, which can't run arbitrary code (unlike pickle files)
    use_safetensors=True,
    # We are using Euler here, but you can also use other samplers.
    # Loading it from the model repo keeps FLUX's timestep-shift settings.
    scheduler=FlowMatchEulerDiscreteScheduler.from_pretrained(model_name, subfolder="scheduler"),
).to(device)
```
Here, we're loading the FLUX model using the `diffusers` library. The model we're using is `black-forest-labs/FLUX.1-dev`, loaded in fp16 precision.
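To see why the precision matters, the memory needed for the weights is roughly the parameter count times the bytes per parameter. This back-of-the-envelope sketch assumes about 12 billion parameters for the FLUX.1 transformer (the size Black Forest Labs announced). It ignores the text encoders, the VAE, and activations, so real usage is higher.

```python
# Bytes used to store one parameter in each precision
BYTES_PER_PARAM = {"fp32": 4, "fp16": 2, "bf16": 2}


def weight_memory_gb(num_params: float, dtype: str) -> float:
    """Approximate memory for the weights alone, in decimal gigabytes (1e9 bytes)."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9


FLUX_TRANSFORMER_PARAMS = 12e9  # assumption: ~12B parameters for the FLUX.1 transformer

for dtype in ("fp32", "fp16"):
    print(dtype, weight_memory_gb(FLUX_TRANSFORMER_PARAMS, dtype), "GB")
# fp32 48.0 GB
# fp16 24.0 GB
```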
There is also FLUX Schnell, a timestep-distilled model that runs inference faster but outputs less detailed images, as well as FLUX Pro, which is closed-source.
We'll use the Euler scheduler here, but you may experiment with this. You can read more on schedulers here.
Since image generation can be resource-intensive, it's crucial to optimize memory usage, especially when running on a CPU or a device with limited memory.
```python
# Recommended if running on MPS or CPU with < 64 GB of RAM
total_memory = psutil.virtual_memory().total
total_memory_gb = total_memory / (1024 ** 3)
if (device == 'cpu' or device == 'mps') and total_memory_gb < 64:
    print("Enabling attention slicing")
    pipeline.enable_attention_slicing()
```
This code checks the machine's total RAM and enables attention slicing when running on CPU or MPS with less than 64 GB. Attention slicing computes attention in smaller chunks, which lowers peak memory usage during image generation at the cost of some speed, a worthwhile trade-off on devices with limited resources.
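The threshold logic is easy to test on its own if you pull it into a function. This sketch uses a hypothetical `should_enable_attention_slicing` helper and takes the byte count as an argument (in the script it comes from `psutil.virtual_memory().total`). Dividing by `1024 ** 3` gives gibibytes, so "64 GB" here effectively means 64 GiB.

```python
def should_enable_attention_slicing(device: str, total_memory_bytes: int,
                                    threshold_gb: float = 64) -> bool:
    """Mirror the script's rule: slice attention on CPU/MPS machines below the threshold.

    Memory is converted with 1024 ** 3, so the threshold is effectively in GiB.
    """
    total_memory_gb = total_memory_bytes / (1024 ** 3)
    return device in ("cpu", "mps") and total_memory_gb < threshold_gb


print(should_enable_attention_slicing("mps", 32 * 1024 ** 3))   # True: 32 GiB Mac
print(should_enable_attention_slicing("cuda", 32 * 1024 ** 3))  # False: GPU path is unaffected
print(should_enable_attention_slicing("cpu", 64 * 1024 ** 3))   # False: not below 64 GiB
```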
Step 3: Creating the API with FastAPI
Next, we'll set up the FastAPI server, which will provide an API to generate images.
```python
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field, conint, confloat
from fastapi.middleware.gzip import GZipMiddleware
from io import BytesIO
import base64

app = FastAPI()

# We will be returning the image as a base64 encoded string
# which we will want compressed
app.add_middleware(GZipMiddleware, minimum_size=1000, compresslevel=7)
```
FastAPI is a popular framework for building web APIs with Python. In this case, we're using it to create a server that can accept requests for image generation. We're also using GZip middleware to compress the response, which is particularly useful when sending images back in base64 format.
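Why compress at all when PNG data is already compressed? Base64 encodes every 3 bytes as 4 ASCII characters, roughly 33% larger, and gzip can win much of that back because the encoded text uses only 64 distinct symbols. This stdlib-only sketch uses random bytes as a stand-in for PNG data (both are close to incompressible), so the sizes it prints are illustrative only.

```python
import base64
import gzip
import random

# Random bytes stand in for PNG data, which is already compressed
raw = random.Random(0).randbytes(30_000)

encoded = base64.b64encode(raw)                    # what the API puts in the JSON body
gzipped = gzip.compress(encoded, compresslevel=7)  # same level as the middleware above

# Expect the base64 text to be 4/3 the raw size, and the gzipped text close to the raw size
print(len(raw), len(encoded), len(gzipped))
```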
In a production environment, you might want to store the generated images in an S3 bucket or other cloud storage and return the URLs instead of the base64-encoded strings, to take advantage of a CDN and other optimizations.
Step 4: Defining the Request Model
We now need to define a model for the requests that our API will accept.
```python
class GenerateRequest(BaseModel):
    prompt: str
    seed: conint(ge=0) = Field(..., description="Seed for random number generation")
    height: conint(gt=0) = Field(..., description="Height of the generated image, must be a positive integer and a multiple of 8")
    width: conint(gt=0) = Field(..., description="Width of the generated image, must be a positive integer and a multiple of 8")
    cfg: confloat(gt=0) = Field(..., description="Guidance scale (passed to the pipeline as guidance_scale), must be a positive number")
    steps: conint(gt=0) = Field(..., description="Number of inference steps, must be a positive integer")
    batch_size: conint(gt=0) = Field(..., description="Number of images to generate in a batch")
```
This `GenerateRequest` model defines the parameters required to generate an image. The `prompt` field is the text description of the image you want to create. The other fields set the seed, the image dimensions, the guidance scale, the number of inference steps, and the batch size.
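Because `height` and `width` must be multiples of 8, a client that accepts arbitrary sizes from users may want to round them before calling the API. Here is a minimal sketch with a hypothetical `snap_to_multiple` helper that rounds down and never goes below the base.

```python
def snap_to_multiple(value: int, base: int = 8) -> int:
    """Round a positive dimension down to the nearest multiple of `base` (at least `base`)."""
    if value <= 0:
        raise ValueError("dimension must be a positive integer")
    return max(base, value - value % base)


print(snap_to_multiple(1000))  # 1000 (already a multiple of 8)
print(snap_to_multiple(1023))  # 1016
print(snap_to_multiple(5))     # 8
```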
Step 5: Creating the Image Generation Endpoint
Now, let's create the endpoint that will handle image generation requests.
```python
@app.post("/")
async def generate_image(request: GenerateRequest):
    # Validate that height and width are multiples of 8
    # as required by FLUX
    if request.height % 8 != 0 or request.width % 8 != 0:
        raise HTTPException(status_code=400, detail="Height and width must both be multiples of 8")

    # Create the generators on CPU so a given seed yields the same noise on any device
    # For a batch of images, seeds will be sequential like n, n+1, n+2, ...
    generator = [
        torch.Generator(device="cpu").manual_seed(i)
        for i in range(request.seed, request.seed + request.batch_size)
    ]

    images = pipeline(
        height=request.height,
        width=request.width,
        prompt=request.prompt,
        generator=generator,
        num_inference_steps=request.steps,
        guidance_scale=request.cfg,
        num_images_per_prompt=request.batch_size
    ).images

    # Convert images to base64 strings
    # (for a production app, you might want to store the
    # images in an S3 bucket and return the URLs instead)
    base64_images = []
    for image in images:
        buffered = BytesIO()
        image.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
        base64_images.append(img_str)

    return {
        "images": base64_images,
    }
```
This endpoint handles the image generation process. It first validates that the height and width are multiples of 8, as required by FLUX. It then generates images based on the provided prompt and returns them as base64-encoded strings.
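One consequence of sequential seeding is that any image from a batch can be regenerated on its own: image k of a batch requested with seed n gets the same starting noise as a request with seed n + k and `batch_size` 1, provided the other parameters match (batched GPU math may still introduce small numerical differences). The sketch below illustrates the idea with Python's `random.Random` standing in for `torch.Generator` and a hypothetical `fake_sample` function standing in for the pipeline's noise sampling; it is not FLUX code.

```python
import random


def batch_seeds(seed: int, batch_size: int) -> list[int]:
    """Seeds used for each image in a batch, matching the endpoint: n, n+1, n+2, ..."""
    return list(range(seed, seed + batch_size))


def fake_sample(seed: int, size: int = 4) -> list[float]:
    """Hypothetical stand-in for drawing initial noise from a seeded generator."""
    rng = random.Random(seed)
    return [rng.gauss(0.0, 1.0) for _ in range(size)]


batch = [fake_sample(s) for s in batch_seeds(42, 3)]  # one request, batch_size=3
single = fake_sample(42 + 2)                           # separate request, seed=44, batch_size=1
print(batch_seeds(42, 3))   # [42, 43, 44]
print(batch[2] == single)   # True: the third image's noise is reproducible on its own
```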
Step 6: Starting the Server
Finally, let's add some code to start the server when the script is run.
```python
@app.on_event("startup")
async def startup_event():
    print("Image generation server running")

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
```
This code starts the FastAPI server on port 8000, making it accessible not only from `http://localhost:8000` but also from other devices on the same network via the host machine’s IP address, thanks to the `0.0.0.0` binding.
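If you want to switch between local-only and network access without editing the script, you could read the bind address from environment variables. `FLUX_HOST` and `FLUX_PORT` are hypothetical names. The defaults reproduce the article's `0.0.0.0:8000`, and setting `FLUX_HOST=127.0.0.1` would restrict the server to the local machine.

```python
import os


def server_address(env=os.environ) -> tuple[str, int]:
    """Return (host, port) for uvicorn from hypothetical FLUX_HOST / FLUX_PORT variables."""
    host = env.get("FLUX_HOST", "0.0.0.0")  # 0.0.0.0 listens on all network interfaces
    port = int(env.get("FLUX_PORT", "8000"))
    return host, port


# In the script you would then call: uvicorn.run(app, host=host, port=port)
print(server_address({"FLUX_HOST": "127.0.0.1"}))  # ('127.0.0.1', 8000)
```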
Step 7: Testing Your Server Locally
Now that your FLUX server is up and running, it's time to test it. You can use `curl`, a command-line tool for making HTTP requests, to interact with your server:
```shell
curl -X POST "http://localhost:8000/" \
  -H "Content-Type: application/json" \
  -d '{
    "prompt": "A futuristic cityscape at sunset",
    "seed": 42,
    "height": 1024,
    "width": 1024,
    "cfg": 3.5,
    "steps": 50,
    "batch_size": 1
  }' | jq -r '.images[0]' | base64 -d > test.png
```
This command works only on UNIX-like systems with the `curl`, `jq`, and `base64` utilities installed. It may also take a few minutes to complete, depending on the hardware hosting the FLUX server.
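On systems without those tools, such as Windows, a small Python client can do the same job. This is a minimal, stdlib-only sketch. The `build_request`, `decode_images`, and `generate` helper names are assumptions, and the request body uses the same fields as the curl example above.

```python
import base64
import json
import urllib.request


def build_request(url: str, body: dict) -> urllib.request.Request:
    """Build a JSON POST request for the FLUX server."""
    data = json.dumps(body).encode("utf-8")
    return urllib.request.Request(
        url, data=data, headers={"Content-Type": "application/json"}, method="POST"
    )


def decode_images(payload: dict) -> list[bytes]:
    """Turn the server's {"images": [base64, ...]} response into raw PNG bytes."""
    return [base64.b64decode(s) for s in payload["images"]]


def generate(url: str, body: dict, timeout: float = 600) -> list[bytes]:
    """Send the request and return the decoded images (generation can take minutes)."""
    with urllib.request.urlopen(build_request(url, body), timeout=timeout) as resp:
        return decode_images(json.load(resp))
```

For example, `generate("http://localhost:8000/", body)` with the same JSON fields as the curl command returns a list of PNG byte strings that you can write to `.png` files.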
Conclusion
Congratulations! You've successfully created your own FLUX server using Python. This setup allows you to generate images based on text prompts via a simple API. If you're not satisfied with the results of the base FLUX model, you might consider fine-tuning the model for even better performance on specific use cases.
Full code
You may find the full code used in this guide below:
```python
device = 'cuda'  # can also be 'cpu' or 'mps'

import os

# Some PyTorch operations are not implemented for MPS yet; let them fall back
# to the CPU. Set this before importing torch so it takes effect.
if device == 'mps':
    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch

if device == 'mps' and not torch.backends.mps.is_available():
    raise RuntimeError("Device set to MPS, but MPS is not available")
elif device == 'cuda' and not torch.cuda.is_available():
    raise RuntimeError("Device set to CUDA, but CUDA is not available")

from diffusers import FlowMatchEulerDiscreteScheduler, FluxPipeline
import psutil

model_name = "black-forest-labs/FLUX.1-dev"
print(f"Loading {model_name} on {device}")
pipeline = FluxPipeline.from_pretrained(
    model_name,
    # Diffusion models are generally trained on fp32, but fp16
    # gets us 99% there in terms of quality, with just half the (V)RAM
    torch_dtype=torch.float16,
    # Only load .safetensors weights, which can't run arbitrary code (unlike pickle files)
    use_safetensors=True,
    # We are using Euler here, but you can also use other samplers.
    # Loading it from the model repo keeps FLUX's timestep-shift settings.
    scheduler=FlowMatchEulerDiscreteScheduler.from_pretrained(model_name, subfolder="scheduler"),
).to(device)

# Recommended if running on MPS or CPU with < 64 GB of RAM
total_memory = psutil.virtual_memory().total
total_memory_gb = total_memory / (1024 ** 3)
if (device == 'cpu' or device == 'mps') and total_memory_gb < 64:
    print("Enabling attention slicing")
    pipeline.enable_attention_slicing()

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field, conint, confloat
from fastapi.middleware.gzip import GZipMiddleware
from io import BytesIO
import base64

app = FastAPI()

# We will be returning the image as a base64 encoded string
# which we will want compressed
app.add_middleware(GZipMiddleware, minimum_size=1000, compresslevel=7)

class GenerateRequest(BaseModel):
    prompt: str
    seed: conint(ge=0) = Field(..., description="Seed for random number generation")
    height: conint(gt=0) = Field(..., description="Height of the generated image, must be a positive integer and a multiple of 8")
    width: conint(gt=0) = Field(..., description="Width of the generated image, must be a positive integer and a multiple of 8")
    cfg: confloat(gt=0) = Field(..., description="Guidance scale (passed to the pipeline as guidance_scale), must be a positive number")
    steps: conint(gt=0) = Field(..., description="Number of inference steps, must be a positive integer")
    batch_size: conint(gt=0) = Field(..., description="Number of images to generate in a batch")

@app.post("/")
async def generate_image(request: GenerateRequest):
    # Validate that height and width are multiples of 8
    # as required by FLUX
    if request.height % 8 != 0 or request.width % 8 != 0:
        raise HTTPException(status_code=400, detail="Height and width must both be multiples of 8")

    # Create the generators on CPU so a given seed yields the same noise on any device
    # For a batch of images, seeds will be sequential like n, n+1, n+2, ...
    generator = [
        torch.Generator(device="cpu").manual_seed(i)
        for i in range(request.seed, request.seed + request.batch_size)
    ]

    images = pipeline(
        height=request.height,
        width=request.width,
        prompt=request.prompt,
        generator=generator,
        num_inference_steps=request.steps,
        guidance_scale=request.cfg,
        num_images_per_prompt=request.batch_size
    ).images

    # Convert images to base64 strings
    # (for a production app, you might want to store the
    # images in an S3 bucket and return the URLs instead)
    base64_images = []
    for image in images:
        buffered = BytesIO()
        image.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
        base64_images.append(img_str)

    return {
        "images": base64_images,
    }

@app.on_event("startup")
async def startup_event():
    print("Image generation server running")

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
```