SQL

SQL databases aren’t actually that scary

Today let’s dip our toes into the SQL world. We’ve spent a lot of time working with pandas and CSV files but data won’t always be packaged that way. It’s a valuable skill to be comfortable with data in all its forms.

We won’t spend too much time studying database queries—in fact there’s only one basic query in the code below. Instead we’ll walk through a quick example of retrieving and visualizing a table.


1. The data.

This post uses the chinook.db example database that’s found all over the internet. You can learn more about it here. The database contains 11 separate tables but we’ll just focus on one: invoices. In a more complex SQL query you might bring multiple tables together with a UNION or JOIN operation.
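As a rough illustration of what a JOIN does, the sketch below builds two tiny in-memory tables and combines them on a shared key. The table layouts, names, and rows are invented for the example and are not the real chinook schema.

```python
import sqlite3
from contextlib import closing

# Hypothetical tables and rows, invented for this sketch (not the chinook schema).
with closing(sqlite3.connect(":memory:")) as connection:
    with closing(connection.cursor()) as cursor:
        cursor.execute("CREATE TABLE customers (CustomerId INTEGER PRIMARY KEY, Name TEXT)")
        cursor.execute(
            "CREATE TABLE invoices (InvoiceId INTEGER PRIMARY KEY, CustomerId INTEGER, Total REAL)"
        )
        cursor.executemany("INSERT INTO customers VALUES (?, ?)", [(1, "Ana"), (2, "Ben")])
        cursor.executemany(
            "INSERT INTO invoices VALUES (?, ?, ?)",
            [(10, 1, 5.0), (11, 2, 3.0), (12, 1, 7.0)],
        )
        # JOIN matches each invoice to the customer with the same CustomerId.
        joined = cursor.execute(
            "SELECT customers.Name, invoices.Total "
            "FROM invoices "
            "JOIN customers ON invoices.CustomerId = customers.CustomerId "
            "ORDER BY invoices.InvoiceId"
        ).fetchall()

print(joined)
```

Each invoice row comes back paired with the name of the customer it belongs to, which is the kind of combination a single-table query can't produce.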

The invoices table contains 9 columns:

InvoiceId
CustomerId
InvoiceDate
BillingAddress
BillingCity
BillingState
BillingCountry
BillingPostalCode
Total

Let’s analyze how invoices vary between countries. We can find the number of invoices sent to each country and their average bill.

Begin with the imports. sqlite3 is built into the Python standard library so there’s no excuse not to take advantage! contextlib, also in the standard library, provides closing, a convenient way to make sure a database connection gets closed. The concept is similar to opening a text file by writing with open("file.txt", "r") as f: the file is only open while the block runs. The idea is to open a database connection at the top of a block and have it closed automatically when the block exits, even if an error occurs along the way.

import sqlite3
from contextlib import closing
import matplotlib.pyplot as plt
from matplotlib import ticker
import seaborn as sns
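To see what closing does in isolation, here is a minimal sketch. It assumes a throwaway in-memory database (":memory:") rather than chinook.db, and checks that the connection can no longer be used once the block has exited.

```python
import sqlite3
from contextlib import closing

# ":memory:" is a throwaway in-memory database used only for this sketch.
with closing(sqlite3.connect(":memory:")) as connection:
    answer = connection.execute("SELECT 1 + 1").fetchone()[0]

# The name `connection` still exists here, but the connection itself is closed,
# so any further query raises sqlite3.ProgrammingError.
try:
    connection.execute("SELECT 1")
    still_open = True
except sqlite3.ProgrammingError:
    still_open = False

print(answer, still_open)
```

One reason to reach for closing: writing with connection: on a sqlite3 connection manages the transaction (commit or rollback) but does not close the connection.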

Before digging in let’s take a closer look at the table’s structure. There are a few different ways to retrieve column info but I find a PRAGMA operation to be the most readable. It also works without querying the table itself, so you don’t have to run a SELECT first and then read cursor.description.

Notice how contextlib.closing is used below. Both the cursor and the database connection are closed automatically when their code blocks exit.

with closing(sqlite3.connect("chinook.db")) as connection:
    with closing(connection.cursor()) as cursor:
        column_info = cursor.execute("PRAGMA table_info(invoices)").fetchall()
        for col in column_info:
            print(col[1])

At this point we’ve printed the names of all columns within the table. But notice we only printed the element at index 1 of each item within column_info. That’s because each item holds more information than just the column’s name. Below you can see all the information available:

(0, 'InvoiceId', 'INTEGER', 1, None, 1)
(1, 'CustomerId', 'INTEGER', 1, None, 0)
(2, 'InvoiceDate', 'DATETIME', 1, None, 0)
(3, 'BillingAddress', 'NVARCHAR(70)', 0, None, 0)
(4, 'BillingCity', 'NVARCHAR(40)', 0, None, 0)
(5, 'BillingState', 'NVARCHAR(40)', 0, None, 0)
(6, 'BillingCountry', 'NVARCHAR(40)', 0, None, 0)
(7, 'BillingPostalCode', 'NVARCHAR(10)', 0, None, 0)
(8, 'Total', 'NUMERIC(10,2)', 1, None, 0)

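Each tuple holds, in order, the column’s index, its name, its declared type, a NOT NULL flag, its default value, and a primary-key flag. It’s often helpful to know a column’s datatype, for example.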
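If remembering positions like col[1] gets tedious, one option is to pair each tuple with field names. The sketch below assumes the six field names SQLite documents for table_info (cid, name, type, notnull, dflt_value, pk) and uses a made-up demo table in an in-memory database instead of chinook.db.

```python
import sqlite3
from contextlib import closing

# Field names of each PRAGMA table_info row, in order.
FIELDS = ("cid", "name", "type", "notnull", "dflt_value", "pk")

with closing(sqlite3.connect(":memory:")) as connection:
    with closing(connection.cursor()) as cursor:
        # `demo` is a made-up table that mimics two of the invoices columns.
        cursor.execute(
            "CREATE TABLE demo (InvoiceId INTEGER PRIMARY KEY NOT NULL, BillingCity NVARCHAR(40))"
        )
        column_info = cursor.execute("PRAGMA table_info(demo)").fetchall()

# Turn each positional tuple into a dict keyed by field name.
columns = [dict(zip(FIELDS, col)) for col in column_info]
for column in columns:
    print(column["name"], column["type"], "required" if column["notnull"] else "optional")
```

The same approach should work on the invoices table by swapping in chinook.db and the real table name.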


2. Get the data from the database.

For this exercise we’ll focus on the BillingCountry and Total columns. That’s all we need to calculate the number of invoices from each country and each country’s average bill. The query selects three values from the invoices table: the country itself, plus two aggregates (COUNT and AVG) that are each assigned an alias.

The query groups by country, which works very similarly to DataFrame.groupby in pandas. All rows that represent a German buyer, for example, are combined into a single row. We do this because we’re interested in describing each country’s buyers as a group.

There are 24 unique countries within the table but we’ll limit the query to 10 for visualization purposes. That’s what LIMIT 10 means. The query specifies descending order to return the top-10 countries ranked by invoice count.

The fetchall method returns a list of tuples, one for each row the query returns. Since we’re grouping by country and limiting the response to 10, rows will be a list of ten 3-tuples.

rows = cursor.execute(
    "SELECT BillingCountry, COUNT(BillingCountry) AS count, AVG(Total) AS avg "
    "FROM invoices "
    "GROUP BY BillingCountry "
    "ORDER BY count DESC "
    "LIMIT 10"
).fetchall()

The response:

[
 ('USA', 91, 5.747912087912091),
 ('Canada', 56, 5.427857142857142),
 ('France', 35, 5.574285714285712),
 ...
]

As you can see, USA is the most common customer country with Canada coming in second. USA customers spend slightly more per invoice than customers in Canada and France.
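To check how GROUP BY, ORDER BY, and LIMIT work together, it can help to run the same query shape on a table small enough to verify by hand. The sketch below uses a handful of invented invoices in an in-memory database; only the query structure matches the real one.

```python
import sqlite3
from contextlib import closing

# Made-up invoices, small enough to verify the aggregates by hand.
sample = [
    ("USA", 4.0), ("USA", 6.0), ("USA", 8.0),
    ("Canada", 3.0), ("Canada", 5.0),
    ("France", 9.0),
]

with closing(sqlite3.connect(":memory:")) as connection:
    with closing(connection.cursor()) as cursor:
        cursor.execute("CREATE TABLE invoices (BillingCountry TEXT, Total REAL)")
        cursor.executemany("INSERT INTO invoices VALUES (?, ?)", sample)
        # Same shape as the real query, but limited to the top 2 countries.
        top_two = cursor.execute(
            "SELECT BillingCountry, COUNT(BillingCountry) AS count, AVG(Total) AS avg "
            "FROM invoices "
            "GROUP BY BillingCountry "
            "ORDER BY count DESC "
            "LIMIT 2"
        ).fetchall()

print(top_two)
```

Six rows collapse into one row per country, and France drops out because it has the fewest invoices, even though its average is the highest.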

Now that we’ve moved data from an SQL table to a Python list, the world is our oyster. We could turn it into a DataFrame, plot it right now, or anything else. Let’s plot it with Seaborn.


3. Prepare the data for Seaborn.

The cleanest way to unpack rows into 3 variables is to use zip along with the * operator.

country, count, average = list(zip(*rows))

The only caveat is that these variables will be tuples rather than the lists Seaborn expects. However that’s easily addressed by calling list() on them in the plotting calls later.

As an aside, if you aren’t familiar with the * operator, an asterisk is used to “unpack” iterables. Below is a more obvious demonstration:

>>> rows = [("John", 1, 10), ("Paul", 2, 20), ("George", 3, 30)]
>>> print(*rows)
('John', 1, 10) ('Paul', 2, 20) ('George', 3, 30)
>>> list(zip(*rows))
[('John', 'Paul', 'George'), (1, 2, 3), (10, 20, 30)]

So just like when you iterate through a zipped pair of lists…

for x, y in zip(x_values, y_values):
    print(x, y)

… The first elements go together, then the second elements, and so on. We use the same principle to separate rows into 3 iterables that we can plot.
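If you’d rather have lists from the start instead of converting at plot time, one possible variation is to wrap the zip in map(list, ...). The rows below are sample values in the same (country, count, average) shape as the query result, with the averages rounded for readability.

```python
# Sample rows in the same (country, count, average) shape as the query result.
rows = [("USA", 91, 5.75), ("Canada", 56, 5.43), ("France", 35, 5.57)]

# zip(*rows) yields one tuple per column; map(list, ...) converts each to a list.
country, count, average = map(list, zip(*rows))

print(country)
print(count)
```

Either version raises a ValueError if rows is empty, because there is nothing to unpack into the three names.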


4. Plot the data.

Now we’ll create 2 barplots to visualize the data. We’re going to place the plots side-by-side so create a 1×2 subplot grid and make the figure size extra wide. Specify an axis for each plot, tweak the style and formatting as desired, and we’re ready to go.

A few quick notes about the code:

  • wspace is useful when you have multiple plots on the same figure. It adjusts horizontal space between subplots.
  • Remember we’re converting the data to lists. It’s currently in tuple form.
  • You can iterate through subplots with a for loop. This might save you from repeating code.
  • I use matplotlib.ticker to format y-tick labels. You could manually create a list of strings and use set_yticklabels, but I think this is a cleaner approach.

sns.set(style="darkgrid", font="Ubuntu Condensed")
fig, axs = plt.subplots(1, 2, figsize=(16, 6))
fig.subplots_adjust(left=0.034, right=0.987, bottom=0.183, top=0.941, wspace=0.12)

sns.barplot(ax=axs[0], x=list(country), y=list(count), color="#9885bf", alpha=0.85)
sns.barplot(ax=axs[1], x=list(country), y=list(average), color="#76d94c", alpha=0.85)

axs[0].set_yticks(range(0, 120, 20))
axs[0].set_ylim(0, 102)
axs[1].set_yticks(range(8))
axs[1].set_ylim(0, 7.1)

for ax in axs:
    plt.setp(ax.xaxis.get_majorticklabels(), rotation=60, ha="right", rotation_mode="anchor", size=11)
    plt.setp(ax.yaxis.get_majorticklabels(), size=14)

axs[1].yaxis.set_major_formatter(ticker.StrMethodFormatter("${x:.2f}"))

axs[0].set_title("Number of Invoices", size=15)
axs[1].set_title("Average Invoice", size=15)

plt.show()

The output:

As you can see, USA customers are most common but they don’t spend as much, on average, as customers from India or the Czech Republic.
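A note on the y-axis labels in the right-hand plot: StrMethodFormatter takes a str.format-style template in which x stands for the tick value and pos for the tick position. As a sketch of the labels that template should produce, the snippet below applies it with plain str.format to a few sample tick values, so it runs without matplotlib.

```python
# Same template passed to ticker.StrMethodFormatter in the plotting code.
template = "${x:.2f}"

# Sample tick values; in the real plot matplotlib supplies these from the axis.
ticks = [0, 1, 5.5, 7]
labels = [template.format(x=tick, pos=i) for i, tick in enumerate(ticks)]

print(labels)
```

Changing the precision (for example ".0f" instead of ".2f") is an easy way to experiment with how the axis will read before re-running the whole plot.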


Although pandas has a convenient read_sql function, certain operations are simply easier when working with sqlite3 directly. I hope this post gives you a little more confidence the next time you’re face-to-face with a database.


Full Code:

import sqlite3
from contextlib import closing
import matplotlib.pyplot as plt
from matplotlib import ticker
import seaborn as sns


with closing(sqlite3.connect("chinook.db")) as connection:
    with closing(connection.cursor()) as cursor:
        rows = cursor.execute(
            "SELECT BillingCountry, COUNT(BillingCountry) AS count, AVG(Total) AS avg "
            "FROM invoices "
            "GROUP BY BillingCountry "
            "ORDER BY count DESC "
            "LIMIT 10"
        ).fetchall()

country, count, average = list(zip(*rows))

sns.set(style="darkgrid", font="Ubuntu Condensed")
fig, axs = plt.subplots(1, 2, figsize=(16, 6))
fig.subplots_adjust(left=0.034, right=0.987, bottom=0.183, top=0.941, wspace=0.12)

sns.barplot(ax=axs[0], x=list(country), y=list(count), color="#9885bf", alpha=0.85)
sns.barplot(ax=axs[1], x=list(country), y=list(average), color="#76d94c", alpha=0.85)

axs[0].set_yticks(range(0, 120, 20))
axs[0].set_ylim(0, 102)
axs[1].set_yticks(range(8))
axs[1].set_ylim(0, 7.1)

for ax in axs:
    plt.setp(ax.xaxis.get_majorticklabels(), rotation=60, ha="right", rotation_mode="anchor", size=11)
    plt.setp(ax.yaxis.get_majorticklabels(), size=14)

axs[1].yaxis.set_major_formatter(ticker.StrMethodFormatter("${x:.2f}"))

axs[0].set_title("Number of Invoices", size=15)
axs[1].set_title("Average Invoice", size=15)

plt.show()