Keywords: Python | scatter plot | linear fit | data visualization | matplotlib
Abstract: This article explores several ways to overlay a linear fit line on a scatter plot in Python. It starts with a basic implementation using NumPy's polyfit, then compares alternative approaches, including seaborn's regplot and statsmodels OLS regression. Complete code examples, parameter explanations, and visualization analysis show how linear regression is applied in data visualization.
Introduction
In data analysis and visualization, combining scatter plots with linear fit lines is a fundamental technique for exploring variable relationships. This article systematically demonstrates how to implement this functionality in Python, focusing on the numpy and matplotlib libraries.
Basic Method: Using NumPy's polyfit
The most straightforward approach uses numpy.polynomial.polynomial.polyfit, the polyfit function from NumPy's polynomial package. It computes the least-squares fit, i.e. the line that minimizes the sum of squared vertical distances between the data points and the line.
import numpy as np
from numpy.polynomial.polynomial import polyfit
import matplotlib.pyplot as plt
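# Generate sample data: a straight line plus Gaussian noise
x = np.arange(10)
y = 5 * x + 10 + np.random.normal(scale=2.0, size=x.size)
# Fit a first-degree polynomial (a straight line); coefficients are returned
# in increasing degree order, so the intercept comes first, then the slope
b, m = polyfit(x, y, 1)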
# Create scatter plot
plt.plot(x, y, '.', label='Original data points')
# Overlay fitted line
plt.plot(x, b + m * x, '-', label='Linear fit')
plt.legend()
plt.show()
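Code explanation: polyfit(x, y, 1) returns the intercept b and the slope m (coefficients in increasing degree order); the argument 1 requests a first-degree polynomial. The fitted line is y = mx + b. Note that the older np.polyfit returns coefficients in the opposite order (slope first), so with that function the unpacking must be m, b = np.polyfit(x, y, 1).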
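To make the coefficient-order difference and the least-squares idea concrete, the sketch below fits the same noisy data three ways: with numpy.polynomial.polynomial.polyfit, with the older np.polyfit, and with the closed-form least-squares formulas m = Σ(x − x̄)(y − ȳ) / Σ(x − x̄)² and b = ȳ − m·x̄. The data are synthetic and seeded (an assumption for reproducibility); all three should agree once the coefficients are unpacked in the right order.

```python
import numpy as np
from numpy.polynomial.polynomial import polyfit

# Synthetic noisy data; the fixed seed only makes the run reproducible
rng = np.random.default_rng(0)
x = np.arange(10, dtype=float)
y = 5 * x + 10 + rng.normal(scale=2.0, size=x.size)

# numpy.polynomial.polynomial.polyfit: increasing degree -> (intercept, slope)
b_poly, m_poly = polyfit(x, y, 1)

# Legacy np.polyfit: decreasing degree -> (slope, intercept)
m_legacy, b_legacy = np.polyfit(x, y, 1)

# Closed-form least-squares solution for a straight line
x_mean, y_mean = x.mean(), y.mean()
m_closed = np.sum((x - x_mean) * (y - y_mean)) / np.sum((x - x_mean) ** 2)
b_closed = y_mean - m_closed * x_mean

print(f"polynomial.polyfit: b={b_poly:.4f}, m={m_poly:.4f}")
print(f"np.polyfit:         b={b_legacy:.4f}, m={m_legacy:.4f}")
print(f"closed form:        b={b_closed:.4f}, m={m_closed:.4f}")
```

Swapping the unpacking order of the two polyfit variants is a common source of silently wrong fit lines, so it is worth checking which function is imported.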
Advanced Visualization: Seaborn Integrated Solution
For rapid prototyping, the seaborn library offers a more concise solution.
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
# Generate linear data with noise
N = 100
x = np.random.rand(N)
y = 3 * x + np.random.rand(N)
# One call draws the scatter plot, the fit line, and its confidence band
sns.regplot(x=x, y=y)
plt.show()
regplot fits a linear regression line, draws it over the scatter plot, and by default shades a 95% confidence interval for the regression estimate, which makes it well suited to exploratory data analysis.
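seaborn documents that regplot's confidence band is estimated by bootstrapping. To illustrate what the shaded region represents, the NumPy-only sketch below resamples the points with replacement, refits the line on each resample, and takes the 2.5th and 97.5th percentiles of the refitted lines at each x. This is an approximation of the idea under stated assumptions (n_boot = 1000, a percentile band, seeded synthetic data), not seaborn's internal code.

```python
import numpy as np
import matplotlib.pyplot as plt

# Synthetic data in the style of the seaborn example (seeded for reproducibility)
rng = np.random.default_rng(0)
N = 100
x = rng.random(N)
y = 3 * x + rng.random(N)

x_grid = np.linspace(x.min(), x.max(), 50)
n_boot = 1000  # assumption: number of bootstrap resamples

# Refit the line on each resample and evaluate it on the grid
boot_lines = np.empty((n_boot, x_grid.size))
for i in range(n_boot):
    idx = rng.integers(0, N, size=N)  # sample point indices with replacement
    m_i, b_i = np.polyfit(x[idx], y[idx], 1)
    boot_lines[i] = m_i * x_grid + b_i

# Fit on the full data plus a 95% percentile band from the bootstrap fits
m_hat, b_hat = np.polyfit(x, y, 1)
fit_line = m_hat * x_grid + b_hat
lower, upper = np.percentile(boot_lines, [2.5, 97.5], axis=0)

plt.scatter(x, y, s=10, label="Data")
plt.plot(x_grid, fit_line, color="C1", label="Linear fit")
plt.fill_between(x_grid, lower, upper, color="C1", alpha=0.2, label="95% bootstrap band")
plt.legend()
plt.show()
```

The band is narrowest near the middle of the data and widens toward the ends, where the slope uncertainty has the most influence.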
Statistical Modeling: Detailed Analysis with statsmodels
When complete statistical information is required, the statsmodels library provides professional regression analysis capabilities.
import statsmodels.api as sm
import numpy as np
import matplotlib.pyplot as plt
X = np.random.rand(100)
Y = X + np.random.rand(100)*0.1
# Add constant term to include intercept
X_with_const = sm.add_constant(X)
# Perform ordinary least squares regression
results = sm.OLS(Y, X_with_const).fit()
# Output detailed statistical results
print(results.summary())
# Visualization
plt.scatter(X, Y)
# Evaluate the fitted line on a grid; params[0] is the intercept, params[1] the slope
X_plot = np.linspace(0, 1, 100)
plt.plot(X_plot, results.params[0] + results.params[1] * X_plot, color='red')
plt.show()
Beyond the fitted line, this method prints a full regression report, including R-squared, coefficient standard errors, t-statistics, and p-values, which makes it suitable for rigorous data analysis.
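To show where the headline numbers in the summary come from, the sketch below recomputes the intercept, slope, R-squared, and classical (non-robust) standard errors with plain NumPy linear algebra. It assumes seeded synthetic data shaped like the example above; with the same data, the results should match results.params, results.rsquared, and results.bse from the default statsmodels OLS fit.

```python
import numpy as np

# Synthetic data mirroring the statsmodels example (seeded)
rng = np.random.default_rng(0)
X = rng.random(100)
Y = X + rng.random(100) * 0.1

# Design matrix with a leading column of ones, like sm.add_constant(X)
A = np.column_stack([np.ones_like(X), X])
params, *_ = np.linalg.lstsq(A, Y, rcond=None)  # [intercept, slope]

# R-squared = 1 - SS_res / SS_tot
residuals = Y - A @ params
n, k = A.shape
ss_res = residuals @ residuals
ss_tot = np.sum((Y - Y.mean()) ** 2)
r_squared = 1 - ss_res / ss_tot

# Classical standard errors: sqrt(diag(sigma^2 * (A^T A)^-1)),
# with sigma^2 estimated as SS_res / (n - k)
sigma2 = ss_res / (n - k)
std_err = np.sqrt(np.diag(sigma2 * np.linalg.inv(A.T @ A)))

print("params (intercept, slope):", params)
print("R-squared:", r_squared)
print("standard errors:", std_err)
```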
Method Comparison and Selection Guidelines
Each method has distinct strengths: numpy's polyfit suits basic plots and custom layouts because it returns the coefficients directly; seaborn's regplot is the quickest route to an exploratory plot with a confidence band; statsmodels suits work that requires statistical inference, such as standard errors and significance tests. Choose based on whether you need only the line, a quick visual check, or a full regression report.
Best Practices and Considerations
In practice: 1) check a residual plot to verify the linearity assumption (residuals should scatter randomly around zero with no visible pattern); 2) on large datasets, keep computational cost in mind, for example by using a plain numpy fit or disabling regplot's bootstrapped confidence band with ci=None; 3) label the axes with plt.xlabel() and plt.ylabel(). Matching the tool to the task keeps the resulting plots both clear and statistically sound.
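The residual check recommended above can be sketched as a two-panel figure: the scatter plot with its fit line on the left, and residuals against fitted values on the right. The data here are seeded synthetic values (an assumption); with a correctly specified linear model the right panel should look like patternless noise around the dashed zero line, while a curve or funnel shape would suggest the linear assumption does not hold.

```python
import numpy as np
import matplotlib.pyplot as plt

# Synthetic data: a straight line plus Gaussian noise (seeded)
rng = np.random.default_rng(1)
x = np.linspace(0, 10, 50)
y = 5 * x + 10 + rng.normal(scale=2.0, size=x.size)

# Least-squares line and its residuals
m, b = np.polyfit(x, y, 1)
fitted = m * x + b
residuals = y - fitted

fig, (ax_fit, ax_res) = plt.subplots(1, 2, figsize=(10, 4))

# Left: data with the fitted line, axes labeled
ax_fit.scatter(x, y, s=12, label="Data")
ax_fit.plot(x, fitted, color="C1", label="Linear fit")
ax_fit.set_xlabel("x")
ax_fit.set_ylabel("y")
ax_fit.legend()

# Right: residuals vs fitted values, with a zero reference line
ax_res.scatter(fitted, residuals, s=12)
ax_res.axhline(0, color="grey", linestyle="--")
ax_res.set_xlabel("Fitted value")
ax_res.set_ylabel("Residual")

fig.tight_layout()
plt.show()
```

Because the fit includes an intercept, the residuals sum to zero up to floating-point error, so the visual check is about their pattern rather than their average.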