Building an image retrieval app with Supabase
In this article, I will show you how to build a generic foundation for an image retrieval app that can later be fine-tuned to a specific use case.
I will guide you through creating a Python backend with machine learning logic.
We will be using the Supabase vecs Python library and Supabase Storage. Afterward, we will integrate it into a frontend application.
I’ll be using the following tech stack:
- Python: all the image retrieval logic will be implemented on the backend.
- FastAPI: a Python web framework for building APIs. We’ll use it to create a minimal REST server that we’ll later integrate with to perform the search.
- Supabase’s vecs: a Python client for managing and querying vector stores. We’ll use it to query similar images.
- Supabase Storage: used to store and retrieve images.
- Hugging Face’s Model Hub: think of it as npm, but for pre-trained models. We’ll obtain the model from there.
- timm: a deep-learning library for image classification.
Let’s get started!
🔗 If you want to dive straight into the code, here’s the repository.
Setting Up Supabase
- Start by navigating to the Supabase Dashboard and clicking the “New Project” button to create a new project.
- Once the project has been created, set up a new Storage bucket to hold the images. The code in this guide reads from a bucket named images through public URLs, so name the bucket accordingly and make it public. Don’t forget to upload a few images!
Building the Python Backend
Prerequisites
- Python 3.9
- pip
- poetry for packaging and dependency management
Install poetry using pip:
sh
pip install poetry
Initialize a new project:
sh
poetry new image-retrieval
Adding Dependencies
We need to add the following dependencies to our project:
- vecs: Supabase Vector Python Client.
- matplotlib: for displaying image results.
- timm: a library for deep learning-based image classification.
- storage3: Python Client library for interacting with Supabase Storage.
- python-dotenv: for loading environment variables (this is the package that provides the dotenv module imported below).
sh
poetry add vecs matplotlib timm storage3 python-dotenv
Processing Existing Images
Our first task is to generate embeddings for our existing images. Embeddings are arrays of numbers (vectors) that represent data, in our case images, so that similar images end up with similar vectors.
This step assumes that you’ve already uploaded some images to the Supabase storage.
Let’s walk through the process of configuring the necessary settings, loading a pre-trained model, and defining a seeding function to process and store the image embeddings.
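To make the idea of an embedding concrete, here is a minimal sketch with made-up three-dimensional vectors. The vectors and their names are purely illustrative assumptions; the embeddings produced later in this guide come from a model and are much longer. Euclidean distance is used here only because it is the simplest way to show that similar images sit close together.

```python
import math

# Hypothetical toy embeddings; real ones are produced by a model
cat_photo = [0.9, 0.1, 0.0]
other_cat_photo = [0.8, 0.2, 0.1]
car_photo = [0.0, 0.2, 0.9]

# A smaller distance means the two vectors (and so the two images) are more alike
cat_to_cat = math.dist(cat_photo, other_cat_photo)
cat_to_car = math.dist(cat_photo, car_photo)

print(f"cat vs. other cat: {cat_to_cat:.3f}")
print(f"cat vs. car:       {cat_to_car:.3f}")
```

The two cat vectors end up much closer to each other than either is to the car vector, which is the property the search relies on.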
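Create a main.py file in the image_retrieval package folder of the new image-retrieval project (the scripts we register later refer to it as image_retrieval.main). We are going to add a few things: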
- Imports
py
import io
import os

import PIL.Image
import requests
import timm
import vecs
from dotenv import load_dotenv
from storage3 import create_client
- Environment variables
You can obtain the Supabase keys and database URL from the Supabase dashboard.
py
load_dotenv()

DB_CONNECTION = os.environ.get("DB_URL", "postgresql://postgres:postgres@localhost:54322/postgres")
SUPABASE_ID = os.environ.get("SUPABASE_ID", "")
SUPABASE_KEY = os.environ.get("SUPABASE_KEY", "")
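The snippet above falls back to empty strings when a variable is missing, which may only surface later as a failed request. One possible alternative, sketched here under the assumption that you would rather fail at startup, is a small helper that raises a clear error. The require_env name is hypothetical and not part of python-dotenv.

```python
import os


def require_env(name: str) -> str:
    """Return the value of an environment variable, or raise if it is unset or empty."""
    value = os.environ.get(name, "")
    if not value:
        raise RuntimeError(f"Missing required environment variable: {name}")
    return value
```

With this helper, SUPABASE_ID = require_env("SUPABASE_ID") would stop the program immediately when the .env file is incomplete.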
- Storage configuration
This sets up the Supabase storage URL and initializes the storage client with the necessary authentication headers.
py
STORAGE_URL = f"https://{SUPABASE_ID}.supabase.co/storage/v1"
headers = {"apiKey": SUPABASE_KEY, "Authorization": f"Bearer {SUPABASE_KEY}"}
storage_client = create_client(STORAGE_URL, headers, is_async=False)
- Vector database setup
This step connects to the vector database and initializes or retrieves a collection for storing image vectors.
py
# Name of the vecs collection; "image_vectors" is an assumed placeholder, any name works
COLLECTION_NAME = "image_vectors"

vx = vecs.create_client(DB_CONNECTION)
images = vx.get_or_create_collection(name=COLLECTION_NAME, dimension=1536)
- Model Initialization
This loads a pre-trained Inception ResNet V2 model and sets it to evaluation mode. Passing num_classes=0 removes the classification head, so the model returns a feature vector (the embedding) instead of class scores.
py
model = timm.create_model(
    'inception_resnet_v2.tf_in1k',
    pretrained=True,
    num_classes=0,
)
model = model.eval()
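The collection was created with dimension=1536, so every vector stored in it needs exactly that many values. As a rough sketch, a small guard like the one below could catch a mismatch early, for instance after swapping in a different model. The check_embedding helper and its error message are assumptions made for illustration; they are not part of vecs or timm.

```python
from typing import Sequence

# Matches the dimension passed to get_or_create_collection in this guide
EXPECTED_DIMENSION = 1536


def check_embedding(embedding: Sequence[float], expected: int = EXPECTED_DIMENSION) -> None:
    """Raise a ValueError when the embedding length differs from the collection dimension."""
    if len(embedding) != expected:
        raise ValueError(
            f"Embedding has {len(embedding)} values, but the collection expects {expected}"
        )
```

Calling it right before an upsert keeps a dimension problem from turning into a harder-to-read database error.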
- Seeding function
py
def seed():
    data_config = timm.data.resolve_model_data_config(model)
    # Use the inference transforms (no random augmentation) so that stored
    # embeddings are deterministic and comparable with the query embeddings
    transforms = timm.data.create_transform(**data_config, is_training=False)
    files = storage_client.from_("images").list()
    for file in files:
        image_url = storage_client.from_("images").get_public_url(file.get("name"))
        response = requests.get(image_url)
        image = PIL.Image.open(io.BytesIO(response.content)).convert("RGB")
        output = model(transforms(image).unsqueeze(0))
        img_emb = output[0].detach().numpy()
        images.upsert([(file.get("name"), img_emb, {"url": image_url})])
    print("Upserted images")
    images.create_index()
    print("Created index")
Now we have a little bit more going on. The seed function does the following:
- Sets up image transformations.
- Lists all files in the Supabase “images” bucket.
- For each file, fetches the image, processes it through the model, and stores its vector in the database.
- After processing all images, creates an index for faster searches.
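The steps above can be sketched without any external service. The example below uses a hypothetical in-memory stand-in for the collection and fake embedding and URL functions, all of which are assumptions for illustration only. It shows the shape of each record (id, vector, metadata) and why keying records by file name means that running the seeding step twice does not create duplicates.

```python
from typing import Callable, Dict, List, Tuple

Record = Tuple[str, List[float], dict]


class InMemoryCollection:
    """Hypothetical stand-in for a vector collection; keeps one record per id."""

    def __init__(self, dimension: int) -> None:
        self.dimension = dimension
        self.records: Dict[str, Record] = {}

    def upsert(self, records: List[Record]) -> None:
        for record_id, vector, metadata in records:
            if len(vector) != self.dimension:
                raise ValueError(f"Expected {self.dimension} values, got {len(vector)}")
            # Insert a new record or overwrite the existing one with the same id
            self.records[record_id] = (record_id, vector, metadata)


def seed_collection(
    file_names: List[str],
    embed: Callable[[str], List[float]],
    public_url: Callable[[str], str],
    collection: InMemoryCollection,
) -> int:
    """Store one (name, embedding, metadata) record per file; return the record count."""
    for name in file_names:
        collection.upsert([(name, embed(name), {"url": public_url(name)})])
    return len(collection.records)


# Fake embedding and URL functions, used only to exercise the loop
demo = InMemoryCollection(dimension=3)
count = seed_collection(
    ["a.jpg", "b.jpg"],
    embed=lambda name: [float(len(name)), 0.0, 1.0],
    public_url=lambda name: f"https://example.com/{name}",
    collection=demo,
)
print(count)
```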
Now that this is done, we can add it as a script to pyproject.toml:
toml
[tool.poetry.scripts]
seed = "image_retrieval.main:seed"
And then we can execute it:
sh
poetry run seed
Once the script is complete, you can navigate to your Supabase dashboard to view the newly created entries.
Searching Images by Image
With Supabase’s vecs
library, querying our embeddings becomes straightforward.
We’ll use an image as our search input and retrieve similar images from our
database. The provided code guides you through this process, from setting up the
necessary configurations to defining a function that returns similar images.
To visualize the results, we’ll use matplotlib
.
We will add a new method in the main.py
file.
py
def get_results(image):
    data_config = timm.data.resolve_model_data_config(model)
    transforms = timm.data.create_transform(**data_config, is_training=False)
    output = model(transforms(image.convert("RGB")).unsqueeze(0))
    img_emb = output[0].detach().numpy()
    # Each result is a tuple of (id, distance score, metadata)
    results = images.query(
        img_emb,
        include_value=True,
        include_metadata=True,
    )
    return results
Here’s what the function does:
- Sets up image transformations. is_training=False indicates that these transformations are for inference mode, not training.
- Converts the input image to RGB format, transforms it, and passes it through the model to obtain its embedding.
- Uses the images.query method to search the database for images whose embeddings are similar to the input image’s embedding.
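Conceptually, a similarity query ranks stored vectors by their distance to the query vector and returns the closest ones. The sketch below does this by brute force in plain Python with cosine distance. It is only an illustration of the idea, not how vecs runs the search, and the record data and the brute_force_query name are made up for the example.

```python
import math
from typing import List, Sequence, Tuple

Record = Tuple[str, Sequence[float], dict]


def cosine_distance(a: Sequence[float], b: Sequence[float]) -> float:
    """Return 1 - cosine similarity: 0 for the same direction, 1 for orthogonal vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / norm


def brute_force_query(
    records: List[Record], query_vector: Sequence[float], limit: int = 10
) -> List[Tuple[str, float, dict]]:
    """Return up to `limit` (id, distance, metadata) tuples, closest first."""
    scored = [
        (record_id, cosine_distance(vector, query_vector), metadata)
        for record_id, vector, metadata in records
    ]
    scored.sort(key=lambda item: item[1])
    return scored[:limit]


# Made-up records in the same (id, vector, metadata) shape used when seeding
records = [
    ("cat-1.jpg", [0.9, 0.1, 0.0], {"url": "https://example.com/cat-1.jpg"}),
    ("car-1.jpg", [0.0, 0.2, 0.9], {"url": "https://example.com/car-1.jpg"}),
    ("cat-2.jpg", [0.8, 0.2, 0.1], {"url": "https://example.com/cat-2.jpg"}),
]

for name, distance, metadata in brute_force_query(records, [1.0, 0.0, 0.0], limit=2):
    print(name, round(distance, 3), metadata["url"])
```

A lower distance means a closer match, which is why the code in this guide later treats a score above 0.5 as too dissimilar to show.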
Now that we have the results, we can use matplotlib to visualize them.
py
import matplotlib.pyplot as plt

def plot_results():
    image = PIL.Image.open("./images/test-image.jpg")
    results = get_results(image)
    columns = 3
    rows = len(results) // columns + 1
    fig = plt.figure(figsize=(10, 10))
    for i, (_, score, metadata) in enumerate(results):
        # score is a distance, so skip results that are too far from the query
        if score > 0.5:
            continue
        response = requests.get(metadata.get("url"))
        image = PIL.Image.open(io.BytesIO(response.content))
        fig.add_subplot(rows, columns, i + 1)
        plt.imshow(image)
        plt.axis("off")
        plt.title(f"Similarity: {1 - score:.2f}")
    plt.show()
The plot_results function does the following:
- Loads a test image from a specified path.
- Finds similar images from the database.
- Sets up a display grid with 3 columns.
- For each result with a similarity of at least 0.5 (that is, a distance score of 0.5 or less):
  - Fetches the image from its URL.
  - Adds it to the display grid with a similarity score label.
- Displays the grid of images.
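The filtering and grid arithmetic can be tried out on their own, without matplotlib. The sketch below is one possible variation rather than the code used in this guide: it drops the results that are too far away first and then packs the remaining ones into consecutive grid cells, so skipped results leave no gaps. The grid_positions name and its return shape are assumptions made for this example.

```python
import math
from typing import List, Tuple


def grid_positions(
    scores: List[float], columns: int = 3, max_distance: float = 0.5
) -> Tuple[int, List[Tuple[int, int, float]]]:
    """Return the row count and (subplot position, result index, similarity) per kept result."""
    # Keep only results whose distance score is within the threshold
    kept = [(index, score) for index, score in enumerate(scores) if score <= max_distance]
    # Enough rows to hold every kept result, with at least one row
    rows = max(1, math.ceil(len(kept) / columns))
    cells = [
        (position + 1, index, round(1 - score, 2))
        for position, (index, score) in enumerate(kept)
    ]
    return rows, cells


rows, cells = grid_positions([0.1, 0.7, 0.3, 0.5])
print(rows, cells)
```

Each subplot position is 1-based, matching what fig.add_subplot expects as its third argument.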
🚨 Note that I have added a test image to an images/ folder in the root of the repository!
Similar to before, we will add plot_results as a new script to the pyproject.toml file so that it can be executed with poetry.
toml
search = "image_retrieval.main:plot_results"
Then, if we run poetry run search, it should display a grid of the stored images that are most similar to the test image you provided.
Creating a REST API
Now that we have the necessary retrieval logic, we will expose an endpoint that the frontend can call. There are a few ways to create REST services in Python, such as Django, Flask, or FastAPI. In this guide, I’ll use FastAPI.
Let’s add it as a dependency:
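sh
poetry add fastapi pydantic uvicorn python-multipart
We’ll also install Pydantic, which is a data validation library for Python, uvicorn, the server we’ll use to run the app, and python-multipart, which FastAPI needs to parse file uploads.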
There are a few things we need to add to the main.py file.
- FastAPI application setup
py
from fastapi import FastAPI, UploadFile

app = FastAPI()
- A new data model
py
from pydantic import BaseModel

class Image(BaseModel):
    name: str
    score: float
    url: str
The Image class describes a single item in the endpoint’s response.
- Endpoint definition:
py
@app.post("/search/")
It establishes a POST endpoint at /search/.
- Search function
py
def search(file: UploadFile) -> list[Image]:
    contents = file.file.read()
    image = PIL.Image.open(io.BytesIO(contents))
    results = get_results(image)
    # Named matches so it does not shadow the module-level images collection
    matches = []
    for name, score, metadata in results:
        matches.append(Image(name=name, score=score, url=metadata.get("url")))
    return matches
It handles incoming POST requests. It expects a file as a request input.
You can start the API with:
sh
poetry run uvicorn image_retrieval.main:app --reload
Frontend integration
And we’re almost there 🎉 The backend part is completed, and now we only need to integrate it.
The provided code snippet demonstrates how to send an image to our backend and retrieve similar images. For a complete app example, please check out the repository.
ts
const getResults = async (file: File) => {
  const formData = new FormData();
  formData.append("file", file);
  try {
    const response = await fetch("http://localhost:8000/search/", {
      method: "POST",
      body: formData,
      headers: {
        accept: "application/json",
      },
    });
    const files = await response.json();
    if (Array.isArray(files)) {
      return files.filter((file) => file.score < 0.5);
    }
    return [];
  } catch (err) {
    console.log(err);
  }
};
Here’s a little demo:
Summary
Now that we have the foundation for an image retrieval app, the next step to transform the current app into a less generic one is to use a specific dataset to train the model. As someone who is new to a lot of machine learning concepts, I am personally excited to learn how to train models and I hope to write about it soon!
To learn more about the tools used in the guide, check out the following links: