Monday, November 16, 2015

Python for Data Analysis Part 20: Plotting with Pandas


* Edit Jan 2021: I recently completed a YouTube video covering topics in this post:




Visualizations are one of the most powerful tools at your disposal for exploring data and communicating data insights. The pandas library includes basic plotting capabilities that let you create a variety of plots from DataFrames. Plots in pandas are built on top of a popular Python plotting library called matplotlib, which comes with the Anaconda Python distribution.
Let's start by loading some packages:
In [2]:
import numpy as np
import pandas as pd
import matplotlib
from ggplot import diamonds

matplotlib.style.use('ggplot')       # Use ggplot style plots*
*Note: If you have not installed ggplot, you can do so by opening a console and running "pip install ggplot" (without quotes).
In this lesson, we're going to look at the diamonds data set that comes with the ggplot library. Let's take a moment to explore the structure of the data before going any further:
In [3]:
diamonds.shape        # Check data shape
Out[3]:
(53940, 10)
In [4]:
diamonds.head(5)
Out[4]:
   carat      cut color clarity  depth  table  price     x     y     z
0   0.23    Ideal     E     SI2   61.5     55    326  3.95  3.98  2.43
1   0.21  Premium     E     SI1   59.8     61    326  3.89  3.84  2.31
2   0.23     Good     E     VS1   56.9     65    327  4.05  4.07  2.31
3   0.29  Premium     I     VS2   62.4     58    334  4.20  4.23  2.63
4   0.31     Good     J     SI2   63.3     58    335  4.34  4.35  2.75
The output shows that the data set contains 10 features for 53,940 different diamonds, including both numeric and categorical variables.
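One way to check which columns are numeric and which are categorical is to inspect the column data types. The sketch below uses a small hand-made DataFrame (toy, a hypothetical stand-in that copies a few columns from the first rows shown above) rather than the full diamonds data:

```python
import pandas as pd

# Hypothetical stand-in: a few columns from the first three rows shown above
toy = pd.DataFrame({
    "carat": [0.23, 0.21, 0.23],
    "cut": ["Ideal", "Premium", "Good"],
    "color": ["E", "E", "E"],
    "price": [326, 326, 327],
})

# Split the column names by data type
numeric_cols = toy.select_dtypes(include="number").columns.tolist()
categorical_cols = toy.select_dtypes(exclude="number").columns.tolist()

print(numeric_cols)      # numeric columns
print(categorical_cols)  # text (categorical) columns
```

Histograms, boxplots and density plots apply to the numeric columns, while barplots apply to the categorical ones.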

Histograms

A histogram is a univariate plot (a plot that displays one variable) that groups a numeric variable into bins and displays the number of observations that fall within each bin. A histogram is a useful tool for getting a sense of the distribution of a numeric variable. Let's create a histogram of diamond carat weight with the df.hist() function:
In [5]:
diamonds.hist(column="carat",        # Column to plot
              figsize=(8,8),         # Plot size
              color="blue")          # Plot color
Out[5]:
array([[<matplotlib.axes._subplots.AxesSubplot object at 0x000000000B25E2E8>]], dtype=object)
We see immediately that the carat weights are positively skewed: most diamonds are around 1 carat or below but there are extreme cases of larger diamonds.
The plot above has fairly wide bins and there doesn't appear to be any data beyond a carat size of 3.5. We can try to get more out of our histogram by adding some additional arguments to control the size of the bins and the limits of the x-axis:
In [6]:
diamonds.hist(column="carat",        # Column to plot
              figsize=(8,8),         # Plot size
              color="blue",          # Plot color
              bins=50,               # Use 50 bins
              range= (0,3.5))        # Limit x-axis range
Out[6]:
array([[<matplotlib.axes._subplots.AxesSubplot object at 0x000000000B499EF0>]], dtype=object)
This histogram gives us a better sense of some subtleties within the distribution, but we can't be sure that it contains all the data. Limiting the X-axis to 3.5 might have cut out some outliers with counts so small that they didn't show up as bars on our original chart. Let's check to see if any diamonds are larger than 3.5 carats:
In [7]:
diamonds[diamonds["carat"] > 3.5]
Out[7]:
       carat        cut color clarity  depth  table  price      x      y     z
23644   3.65       Fair     H      I1   67.1     53  11668   9.53   9.48  6.38
25998   4.01    Premium     I      I1   61.0     61  15223  10.14  10.10  6.17
25999   4.01    Premium     J      I1   62.5     62  15223  10.02   9.94  6.24
26444   4.00  Very Good     I      I1   63.3     58  15984  10.01   9.94  6.31
26534   3.67    Premium     I      I1   62.4     56  16193   9.86   9.81  6.13
27130   4.13       Fair     H      I1   64.8     61  17329  10.00   9.85  6.43
27415   5.01       Fair     J      I1   65.5     59  18018  10.74  10.54  6.98
27630   4.50       Fair     J      I1   65.8     58  18531  10.23  10.16  6.72
27679   3.51    Premium     J     VS2   62.5     59  18701   9.66   9.63  6.03
It turns out that 9 diamonds are bigger than 3.5 carats. Should cutting these diamonds out concern us? On one hand, these outliers have very little bearing on the shape of the distribution. On the other hand, limiting the X-axis to 3.5 implies that no data lies beyond that point. For our own exploratory purposes this is not an issue but if we were to show this plot to someone else, it could be misleading. Including a note that 9 diamonds lie beyond the chart range could be helpful.
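As a rough sketch of how such a note could be added, the example below counts the values beyond the axis limit and writes that count into the plot title. The carat values are made up for illustration, and the wording of the title is only one possible choice:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; not needed inside a notebook
import pandas as pd

# Made-up carat values standing in for diamonds["carat"]
carat = pd.Series([0.3, 0.5, 0.7, 1.0, 1.2, 1.5, 2.0, 3.6, 4.0])

upper = 3.5
n_beyond = int((carat > upper).sum())   # observations hidden by the axis limit

# Values outside the range are left out of the bins
ax = carat.plot(kind="hist", bins=7, range=(0, upper), color="blue")

# State the omission on the chart itself
ax.set_title(f"Carat weight ({n_beyond} diamonds above {upper} carats not shown)")
```

Computing the count from the data, rather than typing it by hand, keeps the note accurate if the cutoff changes.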

Boxplots

Boxplots are another type of univariate plot for summarizing distributions of numeric data graphically. Let's make a boxplot of carat using the df.boxplot() function:
In [8]:
diamonds.boxplot(column="carat")
As we learned in lesson 14, the central box of the boxplot represents the middle 50% of the observations, the central bar is the median and the bars at the end of the dotted lines (whiskers) encapsulate the great majority of the observations. Circles that lie beyond the end of the whiskers are data points that may be outliers.
In this case, our data set has over 50,000 observations and we see many data points beyond the top whisker. We probably wouldn't want to classify all of those points as outliers, but the handful of diamonds at 4 carats and above are definitely far outside the norm.
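To see which points a boxplot would draw as circles, we can compute the whisker limits by hand. The sketch below assumes the default matplotlib rule, where whiskers reach at most 1.5 times the interquartile range beyond the box, and uses made-up carat values:

```python
import pandas as pd

# Made-up carat values standing in for diamonds["carat"]
carat = pd.Series([0.3, 0.4, 0.5, 0.7, 0.9, 1.0, 1.1, 1.5, 2.0, 4.5])

q1 = carat.quantile(0.25)          # bottom edge of the box
q3 = carat.quantile(0.75)          # top edge of the box
iqr = q3 - q1                      # height of the central box

# Assumed default rule: whiskers extend at most 1.5 * IQR past the box
lower_fence = q1 - 1.5 * iqr
upper_fence = q3 + 1.5 * iqr

# Points outside the fences are the ones drawn as individual circles
flagged = carat[(carat < lower_fence) | (carat > upper_fence)]
print(flagged)
```

With a skewed variable like carat, this rule flags many large values, which is why the flagged points are better read as candidates for a closer look than as confirmed outliers.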
One of the most useful features of a boxplot is the ability to make side-by-side boxplots. A side-by-side boxplot takes a numeric variable and splits it based on some categorical variable, drawing a different boxplot for each level of the categorical variable. Let's make a side-by-side boxplot of diamond price split by diamond clarity:
In [9]:
diamonds.boxplot(column="price",        # Column to plot
                 by= "clarity",         # Column to split upon
                 figsize= (8,8))        # Figure size
Out[9]:
<matplotlib.axes._subplots.AxesSubplot at 0xb3534a8>
The boxplot above is curious: we'd expect diamonds with better clarity to fetch higher prices and yet diamonds on the highest end of the clarity spectrum (IF = internally flawless) actually have lower median prices than low clarity diamonds! What gives? Perhaps another boxplot can shed some light on this situation:
In [10]:
diamonds.boxplot(column="carat",        # Column to plot
                 by= "clarity",         # Column to split upon
                 figsize= (8,8))        # Figure size
Out[10]:
<matplotlib.axes._subplots.AxesSubplot at 0xba4c7f0>
The plot above shows that diamonds with low clarity ratings also tend to be larger. Since size is an important factor in determining a diamond's value, it isn't too surprising that low clarity diamonds have higher median prices.
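The same comparison can be made numerically by computing group medians. The sketch below uses a few invented rows, chosen only to mirror the pattern in the plots, so the numbers themselves are not from the diamonds data:

```python
import pandas as pd

# Invented rows: low-clarity (I1) stones are larger here, mirroring the plots
toy = pd.DataFrame({
    "clarity": ["I1", "I1", "I1", "IF", "IF", "IF"],
    "carat":   [1.0, 1.2, 1.5, 0.3, 0.4, 0.5],
    "price":   [3000, 3500, 4200, 900, 1100, 1500],
})

# Median carat and price within each clarity level
medians = toy.groupby("clarity")[["carat", "price"]].median()
print(medians)
```

Running the same groupby on the real data, with diamonds in place of toy, would give the medians that the central bars of the side-by-side boxplots display.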

Density Plots

A density plot shows the distribution of a numeric variable with a continuous curve. It is similar to a histogram, but because it has no discrete bins, a density plot can give a better picture of the underlying shape of a distribution. Create a density plot with series.plot(kind="density"):
In [11]:
diamonds["carat"].plot(kind="density",  # Create density plot
                      figsize=(8,8),    # Set figure size
                      xlim= (0,5))      # Limit x axis values
Out[11]:
<matplotlib.axes._subplots.AxesSubplot at 0xb7f6588>
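The curve in a density plot comes from a kernel density estimate: a small bell curve is centered on each observation and the curves are averaged. As far as I know, pandas hands this computation to SciPy, so kind="density" needs SciPy installed. The NumPy-only sketch below illustrates the idea on made-up values. The bandwidth rule and the evaluation grid are assumptions of this sketch, not necessarily what pandas uses internally:

```python
import numpy as np

# Made-up carat values standing in for diamonds["carat"]
carat = np.array([0.3, 0.3, 0.4, 0.5, 0.7, 0.9, 1.0, 1.0, 1.5, 2.0])

# Scott's rule of thumb for the kernel width (an assumed choice for this sketch)
bandwidth = carat.std(ddof=1) * len(carat) ** (-1 / 5)

grid = np.linspace(-2, 5, 701)                 # points at which to evaluate the curve
z = (grid[:, None] - carat[None, :]) / bandwidth
kernels = np.exp(-0.5 * z ** 2) / (bandwidth * np.sqrt(2 * np.pi))
density = kernels.mean(axis=1)                 # average of one bell curve per observation

area = density.sum() * (grid[1] - grid[0])     # area under the curve, should be close to 1
print(round(area, 3))
```

Because the curve is a smoothed estimate, it can extend below zero carats even though no diamond weighs less than zero, which is one reason to set xlim when plotting.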

Barplots

Barplots are graphs that visually display counts of categorical variables. We can create a barplot by creating a table of counts for a certain variable using the pd.crosstab() function and then passing the counts to df.plot(kind="bar"):
In [12]:
carat_table = pd.crosstab(index=diamonds["clarity"], columns="count")
carat_table
Out[12]:
col_0    count
clarity
I1         741
IF        1790
SI1      13065
SI2       9194
VS1       8171
VS2      12258
VVS1      3655
VVS2      5066
In [13]:
carat_table.plot(kind="bar",
                 figsize=(8,8))
Out[13]:
<matplotlib.axes._subplots.AxesSubplot at 0xba242b0>
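For a single variable, series.value_counts() is another way to get the counts. The sketch below compares the two approaches on a handful of made-up clarity labels:

```python
import pandas as pd

# Made-up clarity labels standing in for diamonds["clarity"]
clarity = pd.Series(["SI1", "VS2", "SI1", "IF", "VS2", "SI1"], name="clarity")

counts = clarity.value_counts().sort_index()          # Series of counts, sorted by label
table = pd.crosstab(index=clarity, columns="count")   # one-column DataFrame of counts

print(counts)
print(table)
```

Either result can be passed to .plot(kind="bar"). The crosstab approach has the advantage of extending naturally to two variables.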
You can use a two dimensional table to create a stacked barplot. Stacked barplots show the distribution of a second categorical variable within each bar:
In [14]:
carat_table = pd.crosstab(index=diamonds["clarity"], 
                          columns=diamonds["color"])

carat_table
Out[14]:
color       D     E     F     G     H     I    J
clarity
I1         42   102   143   150   162    92   50
IF         73   158   385   681   299   143   51
SI1      2083  2426  2131  1976  2275  1424  750
SI2      1370  1713  1609  1548  1563   912  479
VS1       705  1281  1364  2148  1169   962  542
VS2      1697  2470  2201  2347  1643  1169  731
VVS1      252   656   734   999   585   355   74
VVS2      553   991   975  1443   608   365  131
In [15]:
carat_table.plot(kind="bar", 
                 figsize=(8,8),
                 stacked=True)
Out[15]:
<matplotlib.axes._subplots.AxesSubplot at 0xc2981d0>
A grouped barplot is an alternative to a stacked barplot that gives each stacked section its own bar. To make a grouped barplot, do not include the stacked argument (or set stacked=False):
In [16]:
carat_table.plot(kind="bar", 
                 figsize=(8,8),
                 stacked=False)
Out[16]:
<matplotlib.axes._subplots.AxesSubplot at 0xbce8208>
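When the groups differ a lot in size, it can be hard to compare the makeup of each bar. One option, sketched below on made-up rows, is to have pd.crosstab() return row proportions with the normalize argument before plotting:

```python
import pandas as pd

# Made-up rows standing in for the clarity and color columns
toy = pd.DataFrame({
    "clarity": ["I1", "I1", "I1", "I1", "IF", "IF"],
    "color":   ["D", "E", "E", "J", "D", "J"],
})

share_table = pd.crosstab(index=toy["clarity"],
                          columns=toy["color"],
                          normalize="index")   # each row sums to 1

print(share_table)
```

Passing share_table to .plot(kind="bar", stacked=True) would draw bars of equal height, so only the color mix within each clarity level differs.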

Scatterplots

Scatterplots are bivariate (two variable) plots that take two numeric variables and plot data points on the x/y plane. We saw an example of scatterplots in lesson 16 when we created a scatter plot matrix of the mtcars data set. To create a single scatterplot, use df.plot(kind="scatter"):
In [17]:
diamonds.plot(kind="scatter",     # Create a scatterplot
              x="carat",          # Put carat on the x axis
              y="price",          # Put price on the y axis
              figsize=(10,10),
              ylim=(0,20000))  
Out[17]:
<matplotlib.axes._subplots.AxesSubplot at 0xbd35f98>
Although the scatterplot above has many overlapping points, it still gives us some insight into the relationship between diamond carat weight and price: bigger diamonds are generally more expensive.
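Two small additions can help with crowded scatterplots: the alpha argument makes points partly transparent so dense regions look darker, and a correlation coefficient puts a number on the trend. The data below is randomly generated with an invented relationship, so it only illustrates the calls:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; not needed inside a notebook
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
carat = rng.uniform(0.2, 3.0, size=500)
price = 4000 * carat + rng.normal(0, 800, size=500)   # invented relationship
toy = pd.DataFrame({"carat": carat, "price": price})

# alpha=0.1 makes each point 10% opaque, so overlapping points stack up visibly
ax = toy.plot(kind="scatter", x="carat", y="price", alpha=0.1)

# Pearson correlation between the two columns
r = toy["carat"].corr(toy["price"])
print(r)
```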

Line Plots

Line plots are charts used to show the change in a numeric variable based on some other ordered variable. Line plots are often used to plot time series data to show the evolution of a variable over time. Line plots are the default plot type when using df.plot(), so you don't have to specify the kind argument when making a line plot in pandas. Let's create some fake time series data and plot it with a line plot:
In [18]:
# Create some data
years = list(range(1950, 2016))

readings = [(y+np.random.uniform(0,20)-1900) for y in years]

time_df = pd.DataFrame({"year":years,
                        "readings":readings})

# Plot the data
time_df.plot(x="year",
             y="readings",
             figsize=(9,9))
Out[18]:
<matplotlib.axes._subplots.AxesSubplot at 0xbe3bf60>
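A variation on the same idea, sketched below, stores the years in the index of a Series so that .plot() needs no x or y arguments. It also uses a seeded random generator so the fake readings are the same on every run. The seed value is arbitrary:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; not needed inside a notebook
import numpy as np
import pandas as pd

rng = np.random.default_rng(42)               # fixed seed: same readings every run
years = np.arange(1950, 2016)
readings = years - 1900 + rng.uniform(0, 20, size=len(years))

# The index supplies the x values, so no x= or y= is needed
series = pd.Series(readings, index=years, name="readings")
ax = series.plot(figsize=(9, 9))
```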

Saving Plots

If you want to save plots for later use, you can export the plot figure (plot information) to a file. First get the plot figure with plot.get_figure() and then save it to a file with figure.savefig("filename"). You can save plots to a variety of common image file formats, such as png, jpeg and pdf.
In [19]:
my_plot = time_df.plot(x="year",     # Create the plot and save to a variable
             y="readings",
             figsize=(9,9))

my_fig = my_plot.get_figure()            # Get the figure

my_fig.savefig("line_plot_example.png")  # Save to file
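The sketch below saves one figure in two formats. savefig() picks the format from the file extension, and the dpi argument sets the resolution of raster formats like png. The temporary folder and the data are placeholders for this example:

```python
import os
import tempfile
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; not needed inside a notebook
import pandas as pd

# Placeholder data for the example
toy = pd.DataFrame({"year": [2010, 2011, 2012],
                    "readings": [112.0, 118.5, 121.3]})
fig = toy.plot(x="year", y="readings").get_figure()

out_dir = tempfile.mkdtemp()                   # throwaway folder for this sketch
png_path = os.path.join(out_dir, "line_plot_example.png")
pdf_path = os.path.join(out_dir, "line_plot_example.pdf")

fig.savefig(png_path, dpi=150)   # raster image; dpi controls resolution
fig.savefig(pdf_path)            # format is inferred from the file extension
```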

Wrap Up

Pandas plotting functions let you visualize and explore data quickly. They don't offer all the features of dedicated plotting packages like matplotlib or ggplot, but they are often enough to get the job done.

Now that we have developed some tools to explore data, the remainder of this guide will focus on statistics and predictive modeling in Python.
