Fossil Dataset Construction
A notebook demonstrating how to build a dataset using fastai and ipywidgets.
Introduction
I hacked together most of this code in between completing the first few lessons from the fastai course v3. The final product is a site I host internally so that my local paleontologist can label data for me. :)
This was the first big step in one of my projects. The objective of this notebook was to learn about the reddit API and build a dataset I could use to train a model based on the ResNet architecture. It also helped me validate that my home setup was complete.
What it didn't teach me is how to integrate with the broader community, which is another reason I wanted to create this blog.
#hide_output
import os
import sys
import re
import glob
from pathlib import Path
import numpy as np
import pandas as pd
import requests
# For generating the widgets
import ipywidgets as widgets
from ipywidgets import interact, interact_manual
from IPython.display import clear_output, display, Image as Img
from IPython.core.display import HTML
import ipyplot
from fastai.vision import *
from fastai.vision import core
# For interacting with reddit
import praw
from praw.models import MoreComments
# For managing our reddit secrets
from dotenv import load_dotenv
load_dotenv()
Data Organization
To get started, we will follow the example in the first lessons by using a dataset that is labeled by the directory name. We will store images in the path below, which we will also use for training with fastai.
In this case, images are saved to the path /home/tbeck/src/data/fossils. Each image is saved under an 8-digit, zero-filled identifier (see download_image() below). This code does not check for duplicates and does not let the user review the existing dataset (e.g. to clean it), although fastai now ships some tools that might be useful for that purpose.
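If duplicate handling ever becomes necessary, one low-effort option would be to hash each candidate image and skip it when identical bytes already exist on disk. A rough sketch follows; the is_duplicate helper is hypothetical and is not used anywhere in this notebook.
import hashlib
def is_duplicate(content, root):
    """Hypothetical helper: return True if an image with identical bytes already exists under root."""
    new_hash = hashlib.sha256(content).hexdigest()
    for f in Path(root).rglob('*'):
        if f.is_file() and f.suffix.lower() in ('.jpg', '.png'):
            if hashlib.sha256(f.read_bytes()).hexdigest() == new_hash:
                return True
    return False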
dataset_path = Path('/home/tbeck/src/data/fossils')
labels = [x.name for x in dataset_path.iterdir() if (x.is_dir() and x.name not in ['models'])]
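Because the labels come from directory names, the finished dataset can later be loaded straight into fastai for training. The cell below is only a sketch and is not run in this notebook: it assumes the fastai v2 API (hence its own import) and that the class directories already contain images.
from fastai.vision.all import ImageDataLoaders, Resize, cnn_learner, resnet34, error_rate
# Label each image by its parent folder, hold out 20% for validation, and fine-tune a ResNet
dls = ImageDataLoaders.from_folder(dataset_path, valid_pct=0.2, item_tfms=Resize(224))
learn = cnn_learner(dls, resnet34, metrics=error_rate)
learn.fine_tune(4)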
reddit = praw.Reddit(client_id=os.environ['REDDIT_CLIENT_ID'], client_secret=os.environ['REDDIT_SECRET'],
password=os.environ['REDDIT_PASSWORD'], user_agent=os.environ['REDDIT_USER_AGENT'],
username=os.environ['REDDIT_USERNAME'])
fossilid = reddit.subreddit('fossilid')
We need some helper functions to retrieve the images and save them to our dataset. I found that URLs obtained from reddit need some post processing, otherwise they do not render properly.
def download_image(url, dest=None):
"""Given a URL, saves the image in a format the fastai likes."""
dest = Path(dest)
dest.mkdir(exist_ok=True)
files = glob.glob(os.path.join(dest, '*.jpg')) + glob.glob(os.path.join(dest, '*.png'))
i = len(files)
suffix = re.findall(r'\.\w+?(?=(?:\?|$))', url)
suffix = suffix[0] if len(suffix)>0 else '.jpg'
try: core.download_url(url, dest/f"{i:08d}{suffix}", overwrite=True, show_progress=False)
    except Exception as e: print(f"Couldn't download {url}: {e}")
def get_image(url, verbose=False):
"""Given a URL, returns the URL if it looks like it's a URL to an image"""
    IMG_TEST = r"\.(jpg|png)"
p = re.compile(IMG_TEST, re.IGNORECASE)
if p.search(url):
if verbose:
print("url to image")
return url
IMGUR_LINK_TEST = r"((http|https)://imgur.com/[a-z0-9]+)$"
p = re.compile(IMGUR_LINK_TEST, re.IGNORECASE)
if p.search(url):
if verbose:
print("imgur without extension")
return url + '.jpg'
IMGUR_REGEX_TEST = r"((http|https)://i.imgur.com/[a-z0-9\.]+?(jpg|png))"
p = re.compile(IMGUR_REGEX_TEST, re.IGNORECASE)
if p.search(url):
if verbose:
print("imgur with extension")
return url
return None
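A few hypothetical calls illustrate what get_image() returns for each kind of URL (the example links are made up):
# Direct image links (including i.imgur.com) come back unchanged
assert get_image("https://i.redd.it/example.jpg") == "https://i.redd.it/example.jpg"
assert get_image("https://i.imgur.com/abc123.png") == "https://i.imgur.com/abc123.png"
# An imgur page URL without an extension gets '.jpg' appended
assert get_image("https://imgur.com/abc123") == "https://imgur.com/abc123.jpg"
# Anything else (e.g. a video post) is rejected
assert get_image("https://v.redd.it/example") is None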
class Error(Exception):
def __init__(self, msg):
self.msg = msg
class SubmissionStickiedError(Error):
pass
class SubmissionIsVideoError(Error):
pass
class SubmissionNotAnImageError(Error):
pass
class DisplayError(Error):
pass
Now we can query reddit for the images. The method below for building the dataset is a little clunky: I create parallel arrays for each column of data and take care to keep them the same length. A better way would be to delegate building this data structure to a single function so that the loop below is less complex.
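As a rough idea of that refactor, each submission could be reduced to a single record by one helper so the loop only appends dicts. The summarize_submission function below is hypothetical and is not what the next cell actually runs.
def summarize_submission(submission):
    """Hypothetical helper: turn one reddit submission into a single record, raising on anything unusable."""
    if submission.stickied:
        raise SubmissionStickiedError("Post is stickied")
    if submission.is_video:
        raise SubmissionIsVideoError("Post is a video")
    image = get_image(submission.url)
    if image is None:
        raise SubmissionNotAnImageError("Post is not a recognized image url")
    submission.comments.replace_more(limit=None)
    hint = next((c.body for c in submission.comments), None)  # first top-level comment, if any
    return {'submissions': submission, 'images': image, 'comments': hint, 'errors': None}
With that helper, the loop would only need a try/except that appends either summarize_submission(submission) or an error record, and pd.DataFrame(records) would build the table in one step.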
# Fetch submissions for analysis and initialize parallel arrays
submissions = []
images = []
top_comments = []
errors = []
verbose = False
for i, submission in enumerate(reddit.subreddit("fossilid").new(limit=None)):
submissions.append(submission)
images.append(None)
top_comments.append(None)
errors.append(None)
try:
if submission.stickied:
raise SubmissionStickiedError("Post is stickied")
if submission.is_video:
raise SubmissionIsVideoError("Post is a video")
if get_image(submission.url):
if verbose:
print(f"Title: {submission.title}")
images[i] = get_image(submission.url)
try:
if verbose:
display(Img(get_image(submission.url), retina=False, height=400, width=400))
except Exception as err:
if verbose:
print(f"Failed to retrieve transformed image url {get_image(submission.url)} from submission url {submission.url}")
raise DisplayError(f"Failed to retrieve transformed image url {get_image(submission.url)} from submission url {submission.url}")
submission.comments.replace_more(limit=None)
for top_level_comment in submission.comments:
if verbose:
print(f"Comment: \t{top_level_comment.body}")
top_comments[i] = top_level_comment.body
break
else:
raise SubmissionNotAnImageError("Post is not a recognized image url")
except Exception as err:
submissions[i] = None
images[i] = None
top_comments[i] = None
        errors[i] = getattr(err, 'msg', str(err))  # praw/network errors may not define .msg
df = pd.DataFrame({'submissions': submissions, 'images': images, 'comments': top_comments, 'errors': errors})
df.dropna(how='all', inplace=True)
df.dropna(subset=['images'], inplace=True)
debug = False
output2 = widgets.Output()
reset_button = widgets.Button(description='Reset')
def on_reset_button_click(_):
with output2:
clear_output()
int_range.value = 0
classes_dropdown.value = None
new_class.value = ''
reset_button.on_click(on_reset_button_click)
save_button = widgets.Button(
description='Save',
disabled=False,
button_style='', # 'success', 'info', 'warning', 'danger' or ''
tooltip='Save',
icon='check'
)
skip_button = widgets.Button(
description='Skip',
disabled=False,
button_style='', # 'success', 'info', 'warning', 'danger' or ''
tooltip='Skip',
icon=''
)
int_range = widgets.IntSlider(
value=0,
min=0,
max=len(df) - 1,
step=1,
description='Submission:',
disabled=False,
continuous_update=False,
orientation='horizontal',
readout=True,
readout_format='d'
)
img = widgets.Image(
value=requests.get(df.iloc[int_range.value]['images']).content,
format='png',
width=480,
height=640,
)
reddit_link = widgets.Label('Link: ' + str(df.iloc[int_range.value]['submissions'].url))
comment = widgets.Label('Hint: ' + str(df.iloc[int_range.value]['comments']))
local_options = [x.name for x in dataset_path.iterdir() if (x.is_dir() and x.name not in ['models'])]
local_options.sort()
classes_dropdown = widgets.Dropdown(
options=[None] + local_options,
value=None,
description='Class:',
disabled=False,
)
# Free form text widget
new_class = widgets.Text(
    value='',
placeholder='',
description='New Class:',
disabled=False
)
def on_save_button_click(_):
err=None
if len(new_class.value) > 0:
label = new_class.value
elif classes_dropdown.value:
label = classes_dropdown.value
else:
err = "You must specify a label to save to."
with output2:
clear_output()
if err:
print(err)
else:
if debug:
print(f"Would fetch index {int_range.value} from {df.iloc[int_range.value]['images']} to {dataset_path/label}")
else:
                # Save via download_image() so the file gets the next zero-filled index
                download_image(df.iloc[int_range.value]['images'], dest=dataset_path/label)
local_options = [x.name for x in dataset_path.iterdir() if (x.is_dir() and x.name not in ['models'])]
local_options.sort()
classes_dropdown.options = [None] + local_options
int_range.value = int_range.value + 1
classes_dropdown.value = None
new_class.value = ''
def on_skip_button_click(_):
with output2:
clear_output()
int_range.value = int_range.value + 1
classes_dropdown.value = None
new_class.value = ''
save_button.on_click(on_save_button_click)
skip_button.on_click(on_skip_button_click)
def on_value_change(change):
img.value = requests.get(df.iloc[change['new']]['images']).content
reddit_link.value = 'Link: ' + str(df.iloc[int_range.value]['submissions'].url)
comment.value='Hint: ' + str(df.iloc[change['new']]['comments'])
#with output2:
# print(change['new'])
int_range.observe(on_value_change, names='value')
buttons = widgets.HBox([save_button, skip_button, reset_button])
entry = widgets.HBox([classes_dropdown, new_class])
things = widgets.VBox([int_range, img, reddit_link, comment, entry, buttons, output2])
display(things)
Widgets
To build the dataset and give the labeler some context, I created a custom widget for my local paleontologist to use. With it, she can easily navigate the images and apply labels. The widget is a composite of multiple ipywidgets arranged to look like a single form:
- A slider for quickly jumping between submissions
- A dropdown and a text field that can be populated with fixed or freeform labels
- Buttons for saving data, advancing, and resetting the form.
In addition, the top reddit comment is shown in a label as a hint to the user.
When the widget is rendered, it uses the DataFrame to retrieve the image from its URL.
The 'Class' dropdown is created from the labels loaded above. When 'New Class' is not empty and 'Save' is pressed, a new directory is created (if needed) and the image is saved there using fastai (see download_image()).
Here is what the widget ends up looking like:
ipyplot.plot_images(df['images'], max_images=80, img_width=100)
Lessons Learned
The functionality here is somewhat similar to what other widgets, such as superintendent, can do. But this was a good exercise to dust off my ipywidget skills, experiment with reddit's API, and build a custom dataset using fastai.
I built this notebook when fastai's course v3 was still current, back when datasets were built by scraping Google Images rather than by using the Bing image search API.