🐮

【EDA Method】How to plot correlation of DataFrame columns

2024/08/16に公開

Plot correlation

This post shows how to plot the correlation between the columns of a DataFrame and a target column.

It is very useful for visualization and data analysis, so give it a try.

def plot_target_correlation(
    df: pd.DataFrame,
    target_col: str = 'target',
    n_top_features: int = 30,
    color_sequence: list[str] | None = None,
    template_theme: str = "plotly_white"
) -> None:
    """Plot the features most strongly correlated with a target column.

    Pearson correlations between every numeric (or boolean) column of ``df``
    and ``target_col`` are ranked by absolute value. The top
    ``n_top_features`` are drawn as a horizontal Plotly bar chart, with the
    strongest correlation at the top and its sign preserved. Non-numeric
    columns are ignored.

    Args:
        df (pd.DataFrame):
            The input dataframe containing feature and target columns.
        target_col (str, optional):
            Name of the target column. Must be numeric (or boolean).
            Defaults to 'target'.
        n_top_features (int, optional):
            Maximum number of top correlated features to display. Fewer are
            shown when ``df`` has fewer numeric features with a defined
            correlation (constant columns yield NaN and are skipped).
            Defaults to 30.
        color_sequence (list[str], optional):
            Custom list of colors used as the bar colorscale, spread from the
            lowest to the highest displayed correlation value.
            Defaults to None, which uses Plotly's "Plasma" palette
            (dark blue -> purple -> yellow).
        template_theme (str, optional):
            Plotly template theme for visual styling. Defaults to "plotly_white".

    Returns:
        None: This function displays a plot and doesn't return anything.

    Raises:
        KeyError: If ``target_col`` is not a column of ``df``.
        ValueError: If ``target_col`` is not a numeric or boolean column.
    """
    if target_col not in df.columns:
        raise KeyError(f"target column {target_col!r} not found in dataframe")
    if not pd.api.types.is_numeric_dtype(df[target_col]):
        raise ValueError(
            f"target column {target_col!r} must be numeric or boolean, "
            f"got dtype {df[target_col].dtype}"
        )

    # numeric_only=True skips text/categorical/datetime columns; without it
    # pandas >= 2.0 raises on the first non-numeric column.
    correlations = df.corr(numeric_only=True)[target_col]

    # Remove the target by label rather than assuming it sorts first: another
    # feature can tie at |r| == 1, and a NaN self-correlation (constant target)
    # would sort last. Features with undefined (NaN) correlation are dropped too.
    correlations = correlations.drop(index=target_col).dropna()

    # Rank by absolute strength, but keep the signed values for plotting.
    # nlargest returns an empty result for n <= 0 and everything for n > len.
    top_correlations = correlations.loc[
        correlations.abs().nlargest(n_top_features).index
    ]

    # Prepare data for plotting
    feature_names = top_correlations.index
    correlation_values = top_correlations.values

    # Default colorscale: Plotly's "Plasma" sequential palette.
    if color_sequence is None:
        color_sequence = ['#0d0887', '#46039f', '#7201a8', '#9c179e', '#bd3786', '#d8576b', '#ed7953', '#fb9f3a', '#fdca26', '#f0f921']

    # Create the bar plot
    fig = go.Figure()
    fig.add_trace(go.Bar(
        y=feature_names,
        x=correlation_values,
        orientation='h',
        marker=dict(
            color=correlation_values,
            colorscale=color_sequence,
            colorbar=dict(title="Correlation"),
        )
    ))

    # Uppercase only the first character; str.capitalize() would also
    # lowercase the rest (e.g. "SalePrice" -> "Saleprice").
    target_label = target_col[:1].upper() + target_col[1:]

    # Customize the layout
    fig.update_layout(
        # Report how many features are actually shown, which may be fewer
        # than n_top_features.
        title=f"<b>Top {len(feature_names)} Features Correlated with {target_label}</b>",
        xaxis_title="<b>Correlation Coefficient</b>",
        yaxis_title="<b>Feature</b>",
        # Plotly draws the first category at the bottom of a horizontal bar
        # chart; reverse so the strongest correlation is at the top.
        yaxis_autorange="reversed",
        height=800,
        width=1200,
        template=template_theme,
    )

    # Dashed reference line at x=0 spanning the full plot height
    fig.add_vline(x=0, line=dict(color="black", width=1, dash="dash"))

    # Display the plot
    fig.show()
    
# Plot the features most correlated with the default 'target' column.
# NOTE(review): `preprocessed_train_df` is assumed to be a DataFrame built
# earlier in the notebook (not shown here) — confirm it has a 'target' column.
plot_target_correlation(df=preprocessed_train_df)


Discussion