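In this notebook I work through a problem from *Algebra and Trigonometry*, Third Edition, by James Stewart, Lothar Redlin, and Saleem Watson (Brooks/Cole, 2011): fitting an exponential model to decennial population data from 1790 to 2000.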
In [1]:
import matplotlib.pyplot as plt
import numpy as np
# Decennial observations, 1790-2000: one population figure (in millions) per year.
data = dict(
    year=np.arange(1790, 2001, 10),
    population=np.array([
        3.9, 5.3, 7.2, 9.6, 12.9, 17.1, 23.2, 31.4,
        38.6, 50.2, 63.0, 76.2, 92.2, 106.0, 123.2, 132.2,
        151.3, 179.3, 203.3, 226.5, 248.7, 281.4,
    ]),
)
In [2]:
# Decade-spaced years (1790-2000) used as the regression input
data.get('year')
Out[2]:
array([1790, 1800, 1810, 1820, 1830, 1840, 1850, 1860, 1870, 1880, 1890, 1900, 1910, 1920, 1930, 1940, 1950, 1960, 1970, 1980, 1990, 2000])
In [3]:
# Population for each of those years, in millions
data.get('population')
Out[3]:
array([ 3.9, 5.3, 7.2, 9.6, 12.9, 17.1, 23.2, 31.4, 38.6, 50.2, 63. , 76.2, 92.2, 106. , 123.2, 132.2, 151.3, 179.3, 203.3, 226.5, 248.7, 281.4])
In [4]:
# First look at the raw data: population against year.
fig, ax = plt.subplots()
ax.scatter('year', 'population', data=data)
# Trailing ';' suppresses the return-value repr that would otherwise be
# printed as the cell output.
ax.set(xlabel='year', ylabel='population (in millions)',
       title='Population by year, 1790-2000');
Out[4]:
Text(0, 0.5, 'population (in millions)')
- Import the required library (`curve_fit` from SciPy):
In [8]:
from scipy.optimize import curve_fit
- Define the exponential model $P(x) = A\,e^{Bx}$:
In [9]:
def exponential_model(x, A, B):
    """Exponential growth model ``A * exp(B * x)``.

    Parameters
    ----------
    x : float or array_like
        Independent variable (in this notebook, years centered on their mean).
    A : float
        Value of the model at ``x = 0``.
    B : float
        Continuous growth rate per unit of ``x``.

    Returns
    -------
    float or numpy.ndarray
        ``A * exp(B * x)``, evaluated elementwise over ``x``.
    """
    growth_factor = np.exp(B * x)
    return A * growth_factor
- Prepare the data:
Center the year values around their mean so the regression is numerically well conditioned.
In [10]:
# Shift the years so their mean (1895) maps to 0. Without this, B * x would be
# on the order of B * 2000 and A would have to be tiny to compensate, which
# makes the least-squares problem poorly conditioned.
years_centered = data['year'] - data['year'].mean()
- Perform the curve fitting:
In [11]:
# Fit P(x) = A * exp(B * x) to the data by nonlinear least squares.
# Initial guess: ln P = ln A + B * x is linear in x, so a straight-line fit to
# the log of the population yields (B, ln A) estimates close to the optimum.
# A hard-coded A = 1 sits far below the population near the mean year
# (63.0 in 1890), which makes convergence depend on the optimizer's luck.
slope0, intercept0 = np.polyfit(years_centered, np.log(data['population']), 1)
params, covariance = curve_fit(
    exponential_model,
    years_centered,
    data['population'],
    p0=(np.exp(intercept0), slope0),
)
- Generate the fitted curve:
In [12]:
# Evaluate the fitted model on 100 evenly spaced points spanning the
# centered-year range, giving a smooth curve to draw over the data.
x_fit = np.linspace(years_centered.min(), years_centered.max(), num=100)
y_fit = exponential_model(x_fit, *params)
In [13]:
# Overlay the fitted exponential curve on the observed data.
# x_fit is in centered years, so add the mean year back to plot against
# calendar years (the same mean used to build years_centered).
year_offset = np.mean(data['year'])

fig, ax = plt.subplots(figsize=(10, 6))
ax.scatter(data['year'], data['population'], color='blue', label='Data')
ax.plot(x_fit + year_offset, y_fit, color='red', label='Exponential Fit')
ax.set_xlabel('Year')
# The population values are in millions (e.g. 281.4 for 2000), not billions.
ax.set_ylabel('Population (in millions)')
ax.set_title('Exponential Regression')
ax.legend()
ax.grid(True)
plt.show()
In [ ]: