Multimodal Search
In this final exercise, we will learn how to use vector databases to search through images using natural language.
We will be searching through an open source image dataset using an open source model called CLIP. This model is able to encode both images and text into the same embedding space, allowing us to retrieve images that are similar to a user question.
# pip install --quiet datasets gradio lancedb pandas transformers [This has been preinstalled for you]
from transformers import CLIPModel, CLIPProcessor
MODEL_ID = "openai/clip-vit-base-patch32"
device = "cpu"
model = CLIPModel.from_pretrained(MODEL_ID).to(device)
processor = CLIPProcessor.from_pretrained(MODEL_ID)
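As an optional sanity check (not part of the exercise), the snippet below encodes a made-up placeholder image and a short caption with the model we just loaded and compares them in the shared embedding space. The solid-color image and the caption are invented purely for illustration.
# Optional sanity check: CLIP maps text and images into the same 512-D space,
# so a caption and an image can be compared directly with cosine similarity.
# The solid-color image and the caption below are placeholders for illustration.
import torch
import PIL.Image as PILImage

placeholder_img = PILImage.new("RGB", (224, 224), color="red")
inputs = processor(
    text=["a red square"], images=placeholder_img, return_tensors="pt", padding=True
)
with torch.no_grad():
    txt_emb = model.get_text_features(
        input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]
    )
    img_emb = model.get_image_features(pixel_values=inputs["pixel_values"])
print(txt_emb.shape, img_emb.shape)                 # both torch.Size([1, 512])
print(torch.cosine_similarity(txt_emb, img_emb))    # similarity in the shared space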
Setup data model
The dataset itself has an image field and an integer label. We'll also need a field for the embedding vector (CLIP produces 512-dimensional vectors).
For this problem, please add a field named "vector" to the Image class below that holds a 512-dimensional vector.
The image that comes out of the raw dataset is a PIL image, so we'll add some conversion code between PIL images and bytes to make serialization easier.
import io

from lancedb.pydantic import LanceModel, vector
import PIL.Image


class Image(LanceModel):
    image: bytes
    label: int
    vector: vector(512)

    def to_pil(self):
        return PIL.Image.open(io.BytesIO(self.image))

    @classmethod
    def pil_to_bytes(cls, img) -> bytes:
        buf = io.BytesIO()
        img.save(buf, format="PNG")
        return buf.getvalue()
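Just to illustrate the helpers above (this isn't required for the exercise), here's a quick round trip through pil_to_bytes and to_pil with a made-up solid-color image and a zero vector:
# Illustration only: round-trip a made-up PIL image through bytes and back
sample = PIL.Image.new("RGB", (64, 64), color="blue")
record = Image(image=Image.pil_to_bytes(sample), label=0, vector=[0.0] * 512)
print(record.to_pil().size)  # (64, 64)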
Image processing function
Next we will implement a function to process batches of data from the dataset. We will be using the zh-plus/tiny-imagenet dataset from Hugging Face Datasets. This dataset has an image column and a label column.
For this problem, please fill in the code to extract the image embeddings from the image using the CLIP model.
def process_image(batch: dict) -> dict:
    image = processor(text=None, images=batch["image"], return_tensors="pt")[
        "pixel_values"
    ].to(device)

    # create the image embedding from the processed image and the model
    img_emb = model.get_image_features(image)
    batch["vector"] = img_emb.cpu()
    batch["image_bytes"] = [Image.pil_to_bytes(img) for img in batch["image"]]
    return batch
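If you want to sanity-check process_image before running it over the whole dataset, something like the following (with a couple of made-up blank images) should print a batch of 512-D vectors:
# Optional check on a hand-made batch of two blank images (illustration only)
dummy_batch = {
    "image": [PIL.Image.new("RGB", (64, 64)) for _ in range(2)],
    "label": [0, 1],
}
out = process_image(dummy_batch)
print(out["vector"].shape)      # torch.Size([2, 512])
print(len(out["image_bytes"]))  # 2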
Table creation
Please create a LanceDB table called image_search to store the image, label, and vector.
import lancedb
TABLE_NAME = "image_search"
uri = "data/.lancedb/"
db = lancedb.connect(uri)
tbl = db.create_table(TABLE_NAME, schema=Image, exist_ok=True)
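If you'd like to confirm the table was created, the connection and table objects expose a couple of introspection helpers (exact output may vary slightly across lancedb versions):
# Optional: confirm the table exists and inspect its (Arrow) schema
print(db.table_names())
print(tbl.schema)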
Adding data
Now we're ready to process the images and generate embeddings. Please write a function called datagen that calls process_image on each image in the validation set (10K images) and returns a list of Image instances.
HINT
- You may find it faster to use the dataset.map function.
- You'll want to store the image_bytes field that is returned by process_image.
from datasets import load_dataset


def datagen() -> list[Image]:
    dataset = load_dataset("zh-plus/tiny-imagenet", split="valid")
    batches = dataset.map(process_image, batched=True, batch_size=64)

    # iterating over the mapped dataset yields one example at a time,
    # so build an Image instance per example
    return [
        Image(
            image=example["image_bytes"],
            label=example["label"],
            vector=example["vector"],
        )
        for example in batches
    ]
Now call the function you just wrote and add the generated instances to the LanceDB table.
data = datagen()
len(data)
10000
table = db[TABLE_NAME]
table.add(data)
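To verify that the rows actually landed in the table, a quick count should match the size of the validation split (count_rows is available in recent lancedb releases; len(table) also works):
# Optional: verify the number of rows in the table
print(table.count_rows())  # expected: 10000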
Encoding user queries
We have image embeddings, but how do we generate the embeddings for the user query? Furthermore, how can the image embeddings and the text embeddings possibly share the same feature space? This is where the power of CLIP comes in.
Please write a function to turn user query text into an embedding in the same latent space as the images.
HINT: You can refer to the CLIPModel documentation.
from transformers import CLIPTokenizerFast

MODEL_ID = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(MODEL_ID).to(device)
tokenizer = CLIPTokenizerFast.from_pretrained(MODEL_ID)


def embed_func(query):
    inputs = tokenizer([query], padding=True, return_tensors="pt")

    # generate the text embeddings
    text_features = model.get_text_features(**inputs)
    return text_features.detach().numpy()[0]
embed_func('a cat')
array([ 1.98062390e-01, -2.04020143e-01, -1.53303251e-01, ...,
       -4.51412261e-01, -5.66431582e-01,  5.96289933e-02], dtype=float32)
(512-dimensional text embedding; output truncated for readability)
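To build some intuition for the text embedding space (purely illustrative; the queries below are made up), semantically related queries should score a higher cosine similarity than unrelated ones:
# Illustration: related queries land closer together in CLIP's text space
import numpy as np

def cosine(a, b):
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

cat = embed_func("a cat")
kitten = embed_func("a kitten")
truck = embed_func("a garbage truck")
print(cosine(cat, kitten), cosine(cat, truck))  # the first value should be larger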
Core search function
Now let's write the core search function find_images, which takes a text query as input and returns a list of the PIL images most similar to the query.
def find_images(query):
    # Generate the embedding for the query
    emb = embed_func(query)

    # Search for the closest 9 images
    rs = table.search(emb).limit(9).to_pydantic(Image)

    # Return PIL instances for visualization
    return [image.to_pil() for image in rs]
find_images("fish")[0]
Create an App
Let's use Gradio to create a small app to search through the images. The code below has been completed for you. It includes:
- A text input where the user can type in a query
- A "Submit" button that finds images similar to the input query and displays the results
- A Gallery component that displays the found images
import gradio as gr

with gr.Blocks() as demo:
    with gr.Row():
        vector_query = gr.Textbox(value="fish", show_label=False)
        b1 = gr.Button("Submit")
    with gr.Row():
        gallery = gr.Gallery(
            label="Found images",
            show_label=False,
            elem_id="gallery",
            columns=[3],
            rows=[3],
            object_fit="contain",
            height="auto",
        )

    b1.click(find_images, inputs=vector_query, outputs=gallery)

demo.launch(server_name="0.0.0.0", inline=False)
Running on local URL: http://0.0.0.0:7860
To create a public link, set `share=True` in `launch()`.
To view the interface, click on the Links button at the bottom of the workspace window. Then click on gradio. This will open a new browser window with the interface.
Now try a bunch of different queries and see the results. Out of the box, CLIP search results leave a lot of room for improvement. More advanced applications in this space can improve these results in a number of ways, such as retraining the model with your own dataset and labels, or using both the image and text vectors to train the index. The details, however, are beyond the scope of this lesson.
Summary
Congrats!
Through this exercise, you learned how to use CLIP to generate image and text embeddings. You've mastered how to use vector databases to enable searching through images using natural language. And you even created a simple app to show off your work.
Great job!
Created: 2024-10-23