Scatterplot Matrix in Python

How to make scatterplot matrices or sploms natively in Python with Plotly.



Scatter matrix with Plotly Express

A scatterplot matrix is a grid of plots built from n numerical arrays (data variables), $X_1,X_2,…,X_n$, of the same length. Cell (i,j) of the matrix displays the scatter plot of variable $X_i$ versus $X_j$.

The Plotly Express function px.scatter_matrix plots the scatter matrix for the columns of a dataframe. By default, all columns are used.

Plotly Express is the easy-to-use, high-level interface to Plotly, which operates on a variety of types of data and produces easy-to-style figures.

In [1]:
import plotly.express as px
df = px.data.iris()
fig = px.scatter_matrix(df)
fig.show()

Specify the columns to be represented with the dimensions argument, and set colors using a column of the dataframe:

In [2]:
import plotly.express as px
df = px.data.iris()
fig = px.scatter_matrix(df,
    dimensions=["sepal_length", "sepal_width", "petal_length", "petal_width"],
    color="species")
fig.show()

Styled Scatter Matrix with Plotly Express

The scatter matrix plot can be configured through the parameters of px.scatter_matrix, and fine-tuned with fig.update_traces (see the next section to learn more about the options).

In [3]:
import plotly.express as px
df = px.data.iris()
fig = px.scatter_matrix(df,
    dimensions=["sepal_length", "sepal_width", "petal_length", "petal_width"],
    color="species", symbol="species",
    title="Scatter matrix of iris data set",
    labels={col: col.replace('_', ' ') for col in df.columns})  # replace underscores with spaces in axis labels
fig.update_traces(diagonal_visible=False)
fig.show()

Scatter matrix (splom) with go.Splom

If Plotly Express does not provide a good starting point, it is possible to use the more generic go.Splom class from plotly.graph_objects. All its parameters are documented in the reference page https://plotly.com/python/reference/splom/.

The Plotly splom trace does not require setting $x=X_i$ and $y=X_j$ for each scatter plot. All arrays, $X_1,X_2,…,X_n$, are passed once, through a list of dicts called dimensions; each array/variable represents one dimension.

A trace of type splom is defined as follows:

trace=go.Splom(dimensions=[dict(label='string-1',
                                values=X1),
                           dict(label='string-2',
                                values=X2),
                           .
                           .
                           .
                           dict(label='string-n',
                                values=Xn)],
                           ....
               )

The label of each dimension is used as the axis title of the corresponding matrix cells.

Splom of the Iris data set

In [4]:
import plotly.graph_objects as go
import pandas as pd

df = pd.read_csv('https://raw.githubusercontent.com/plotly/datasets/master/iris-data.csv')

# The Iris dataset contains four data variables, sepal length, sepal width, petal length,
# petal width, for 150 iris flowers. The flowers are labeled as `Iris-setosa`,
# `Iris-versicolor`, `Iris-virginica`.

# Define indices corresponding to flower categories, using pandas label encoding
index_vals = df['class'].astype('category').cat.codes

fig = go.Figure(data=go.Splom(
                dimensions=[dict(label='sepal length',
                                 values=df['sepal length']),
                            dict(label='sepal width',
                                 values=df['sepal width']),
                            dict(label='petal length',
                                 values=df['petal length']),
                            dict(label='petal width',
                                 values=df['petal width'])],
                text=df['class'],
                marker=dict