python matplotlib scikit-learn yellowbrick

Highlighting specific data points for parallel coordinates plot


I'm looking for help highlighting/coloring particular data points on a parallel coordinates plot. I can't seem to find a way that works.

Essentially, I want to plot all the data as below, and then take, e.g., the data points at indices [0, 1, 2] and draw them in a separate color to highlight them (and, if possible, also make them thicker). Any suggestions?

from sklearn import datasets
from yellowbrick.features import ParallelCoordinates

iris = datasets.load_iris()
X = iris.data[:, :]
y = iris.target

features = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']
classes = ['Iris-setosa', 'Iris-versicolor', 'Iris-virginica']
title = "Plot over Iris Data"

# Instantiate the visualizer
visualizer = ParallelCoordinates(
    classes=classes, features=features, fast=False, alpha=.40, title=title)

# Fit the visualizer and display it
visualizer.fit_transform(X, y)
visualizer.finalize()  # creates title, legend, etc.

visualizer.ax.tick_params(labelsize=22)  # change size of tick labels
visualizer.ax.title.set_fontsize(30)  # change size of title

for text in visualizer.ax.legend_.texts:  # change size of legend texts
    text.set_fontsize(20)

visualizer.fig.tight_layout()  # fit all texts nicely into the surrounding figure
visualizer.fig.show()

Solution

  • Currently, with fast=False, ParallelCoordinates.draw() plots the data points one at a time, in order. As a result, the child Line2D instances of visualizer.ax follow the order of the rows in X, so you can restyle them by index:

    from sklearn import datasets
    from yellowbrick.features import ParallelCoordinates
    
    # New code ----------------------
    import matplotlib.pyplot as plt
    special_lines = [0, 1, 2]
    # Put any property you want here.
    special_properties = {'linestyle': '--', 'color': 'k', 
                          'linewidth': 5, 'zorder': float('inf'), 
                          'alpha': 1}
    # End of new code ---------------
    
    iris = datasets.load_iris()
    X = iris.data[:, :]
    y = iris.target
    
    features = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']
    classes = ['Iris-setosa', 'Iris-versicolor', 'Iris-virginica']
    title = "Plot over Iris Data"
    
    # Instantiate the visualizer
    visualizer = ParallelCoordinates(
        classes=classes, features=features, fast=False, alpha=.40, title=title)
    
    # Fit the visualizer and display it
    visualizer.fit_transform(X, y)
    
    # New code ----------------------
    lines = visualizer.ax.get_lines()  # one Line2D per row, in data order
    for i in special_lines:
        plt.setp(lines[i], **special_properties)
    # End of new code ---------------
    
    visualizer.finalize()  # creates title, legend, etc.
    
    visualizer.ax.tick_params(labelsize=22)  # change size of tick labels
    visualizer.ax.title.set_fontsize(30)  # change size of title
    
    for text in visualizer.ax.legend_.texts:  # change size of legend texts
        text.set_fontsize(20)

    visualizer.fig.tight_layout()  # fit all texts nicely into the surrounding figure
    visualizer.fig.show()
    

    Result: the rows at indices 0, 1 and 2 are drawn as thick, black, dashed lines on top of all the other lines.

    Please note that drawing the lines in order is not documented; it is just how the visualizer is currently implemented. Hence it could change in a future release (even though I don't expect it to). A safer way is to check whether each line's data matches the data the visualizer actually plotted. In general this means comparing against the transformed data, because ParallelCoordinates can also normalize the data (the normalize parameter). That is not the case here, but in general we should do:

    # Perform AFTER visualizer.fit_transform(X, y).
    import numpy as np
    import matplotlib.pyplot as plt
    
    # Rows to highlight, in the form the visualizer plotted them
    transformed_data = list(visualizer.transform(X[special_lines, :]))
    for line in visualizer.ax.get_lines():
        # Restyle the line if its y-values equal any of the highlighted rows
        if any(np.array_equal(arr, line.get_ydata()) for arr in transformed_data):
            plt.setp(line, **special_properties)
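The index-based approach relies on ax.get_lines() returning the lines in the order they were plotted. The minimal sketch below checks that idea with plain matplotlib so it runs without Yellowbrick installed; the random data and the manual plotting loop are stand-ins for the visualizer (assumptions for illustration, not Yellowbrick code):

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np

# Stand-in data: 10 rows, 4 features (made up for illustration)
rng = np.random.default_rng(1)
data = rng.random((10, 4))
xs = np.arange(data.shape[1])

fig, ax = plt.subplots()
for row in data:  # one Line2D per row, added in row order
    ax.plot(xs, row, color="tab:blue", alpha=0.4)

special_lines = [0, 1, 2]
special_properties = {"linestyle": "--", "color": "k", "linewidth": 5, "alpha": 1}

lines = ax.get_lines()  # lines come back in the order they were added
for i in special_lines:
    plt.setp(lines[i], **special_properties)
```

If the visualizer ever stopped drawing rows in order, lines[i] would no longer correspond to X[i], which is why the value-matching version is the safer choice.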
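Another option, which depends neither on drawing order nor on value matching, is to draw the selected rows again as new lines on top of the plot. This sketch assumes that ParallelCoordinates places the feature axes at x = 0, 1, ..., n_features - 1 and that no normalization is applied (normalize=None, as in the question). highlight_rows is a hypothetical helper, and the random data and manual loop only stand in for the visualizer so the example runs on its own:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np


def highlight_rows(ax, X, rows, **line_kwargs):
    """Hypothetical helper: overlay the selected rows of X as new polylines.

    Assumes the parallel-coordinates axes sit at x = 0 .. n_features - 1 and
    that X is on the same scale as the plotted data (no normalization).
    Returns the newly created Line2D objects.
    """
    X = np.asarray(X)
    xs = np.arange(X.shape[1])
    # Default highlight style; any keyword argument overrides it
    style = {"color": "k", "linewidth": 5, "linestyle": "--", "alpha": 1, "zorder": 10}
    style.update(line_kwargs)
    new_lines = []
    for i in rows:
        new_lines.extend(ax.plot(xs, X[i], **style))
    return new_lines


# Stand-in for a parallel coordinates plot: one thin line per row
rng = np.random.default_rng(0)
data = rng.random((20, 4))
fig, ax = plt.subplots()
for row in data:
    ax.plot(np.arange(data.shape[1]), row, color="tab:blue", alpha=0.4)

highlighted = highlight_rows(ax, data, [0, 1, 2])
```

With the visualizer, you would call highlight_rows(visualizer.ax, X, [0, 1, 2]) after visualizer.fit_transform(X, y). The original thin lines for those rows are still drawn underneath, but the thick overlay covers them. Unlike value matching, this also cannot accidentally highlight other rows that happen to contain identical values.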