인공지능 공부/딥러닝 논문읽기
(코로나 바이러스 예측) Analysis of COVID-19 data using Python
앨런튜링_
2021. 9. 23. 19:06
import urllib
import datetime as dt
from matplotlib import pyplot as plt
import matplotlib
import pandas as pd
import seaborn as sns

# Daily per-location COVID-19 figures (the "ecdc" export) from Our World in Data.
# NOTE(review): this export may no longer be served at this URL -- confirm it
# still resolves (and has the same columns) before re-running the notebook.
url = "https://covid.ourworldindata.org/data/ecdc/full_data.csv"
CVD = pd.read_csv(url)
print(CVD.head(5))
date location new_cases new_deaths total_cases total_deaths \
0 2019-12-31 Afghanistan 0.0 0.0 NaN NaN
1 2020-01-01 Afghanistan 0.0 0.0 NaN NaN
2 2020-01-02 Afghanistan 0.0 0.0 NaN NaN
3 2020-01-03 Afghanistan 0.0 0.0 NaN NaN
4 2020-01-04 Afghanistan 0.0 0.0 NaN NaN
weekly_cases weekly_deaths biweekly_cases biweekly_deaths
0 NaN NaN NaN NaN
1 NaN NaN NaN NaN
2 NaN NaN NaN NaN
3 NaN NaN NaN NaN
4 NaN NaN NaN NaN
# Inspect column dtypes; 'date' is still an object (string) column at this point.
column_types = CVD.dtypes
print(column_types)
date object
location object
new_cases float64
new_deaths float64
total_cases float64
total_deaths float64
weekly_cases float64
weekly_deaths float64
biweekly_cases float64
biweekly_deaths float64
dtype: object
# Convert the 'date' column from 'YYYY-MM-DD' strings to datetime64 values.
# pd.to_datetime parses the whole column in one vectorised call instead of one
# datetime.strptime call per row; a malformed date still raises ValueError.
CVD['date'] = pd.to_datetime(CVD['date'], format='%Y-%m-%d')
print(CVD.dtypes)
date datetime64[ns]
location object
new_cases float64
new_deaths float64
total_cases float64
total_deaths float64
weekly_cases float64
weekly_deaths float64
biweekly_cases float64
biweekly_deaths float64
dtype: object
countries = ['United States', 'Spain', 'Italy', 'South Korea']
# .copy() makes CVD_country an independent frame rather than a filtered slice of
# CVD, so modifying it later (set_index, adding 'mortality_rate') does not raise
# pandas' SettingWithCopyWarning.
CVD_country = CVD[CVD.location.isin(countries)].copy()
CVD_country
date location new_cases new_deaths total_cases total_deaths weekly_cases weekly_deaths biweekly_cases biweekly_deaths
27241 2019-12-31 Italy 0.0 0.0 NaN NaN NaN NaN NaN NaN
27242 2020-01-01 Italy 0.0 0.0 NaN NaN NaN NaN NaN NaN
27243 2020-01-02 Italy 0.0 0.0 NaN NaN NaN NaN NaN NaN
27244 2020-01-03 Italy 0.0 0.0 NaN NaN NaN NaN NaN NaN
27245 2020-01-04 Italy 0.0 0.0 NaN NaN NaN NaN NaN NaN
... ... ... ... ... ... ... ... ... ... ...
56361 2020-11-25 United States 170293.0 2224.0 12591165.0 259925.0 1231363.0 11238.0 2333339.0 20242.0
56362 2020-11-26 United States 186589.0 2341.0 12777754.0 262266.0 1247947.0 11729.0 2376622.0 20466.0
56363 2020-11-27 United States 106091.0 1189.0 12883845.0 263455.0 1166018.0 10900.0 2329044.0 21025.0
56364 2020-11-28 United States 207913.0 1404.0 13091758.0 264859.0 1177814.0 10446.0 2352144.0 20514.0
56365 2020-11-29 United States 154893.0 1204.0 13246651.0 266063.0 1157213.0 10164.0 2341760.0 20463.0
1339 rows × 10 columns
# Index the selected countries' rows by date. Reassigning the result instead of
# using inplace=True yields a fresh frame, so it is never an in-place mutation of
# something that may still be a slice of CVD.
CVD_country = CVD_country.set_index('date')
CVD_country
location new_cases new_deaths total_cases total_deaths weekly_cases weekly_deaths biweekly_cases biweekly_deaths
date
2019-12-31 Italy 0.0 0.0 NaN NaN NaN NaN NaN NaN
2020-01-01 Italy 0.0 0.0 NaN NaN NaN NaN NaN NaN
2020-01-02 Italy 0.0 0.0 NaN NaN NaN NaN NaN NaN
2020-01-03 Italy 0.0 0.0 NaN NaN NaN NaN NaN NaN
2020-01-04 Italy 0.0 0.0 NaN NaN NaN NaN NaN NaN
... ... ... ... ... ... ... ... ... ...
2020-11-25 United States 170293.0 2224.0 12591165.0 259925.0 1231363.0 11238.0 2333339.0 20242.0
2020-11-26 United States 186589.0 2341.0 12777754.0 262266.0 1247947.0 11729.0 2376622.0 20466.0
2020-11-27 United States 106091.0 1189.0 12883845.0 263455.0 1166018.0 10900.0 2329044.0 21025.0
2020-11-28 United States 207913.0 1404.0 13091758.0 264859.0 1177814.0 10446.0 2352144.0 20514.0
2020-11-29 United States 154893.0 1204.0 13246651.0 266063.0 1157213.0 10164.0 2341760.0 20463.0
1339 rows × 9 columns
# Mortality rate per row = total deaths / total cases (NaN wherever either is NaN).
# .assign builds a new frame instead of setting a column on a possible slice,
# which is what triggered the SettingWithCopyWarning.
CVD_country = CVD_country.assign(
    mortality_rate=CVD_country['total_deaths'] / CVD_country['total_cases']
)
<ipython-input-76-9aae57fb628b>:1: SettingWithCopyWarning:
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead
See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
CVD_country['mortality_rate'] = CVD_country['total_deaths']/CVD_country['total_cases']
# Echo the frame (notebook display) to inspect the added 'mortality_rate' column.
CVD_country
location new_cases new_deaths total_cases total_deaths weekly_cases weekly_deaths biweekly_cases biweekly_deaths mortality_rate
date
2019-12-31 Italy 0.0 0.0 NaN NaN NaN NaN NaN NaN NaN
2020-01-01 Italy 0.0 0.0 NaN NaN NaN NaN NaN NaN NaN
2020-01-02 Italy 0.0 0.0 NaN NaN NaN NaN NaN NaN NaN
2020-01-03 Italy 0.0 0.0 NaN NaN NaN NaN NaN NaN NaN
2020-01-04 Italy 0.0 0.0 NaN NaN NaN NaN NaN NaN NaN
... ... ... ... ... ... ... ... ... ... ...
2020-11-25 United States 170293.0 2224.0 12591165.0 259925.0 1231363.0 11238.0 2333339.0 20242.0 0.020643
2020-11-26 United States 186589.0 2341.0 12777754.0 262266.0 1247947.0 11729.0 2376622.0 20466.0 0.020525
2020-11-27 United States 106091.0 1189.0 12883845.0 263455.0 1166018.0 10900.0 2329044.0 21025.0 0.020448
2020-11-28 United States 207913.0 1404.0 13091758.0 264859.0 1177814.0 10446.0 2352144.0 20514.0 0.020231
2020-11-29 United States 154893.0 1204.0 13246651.0 266063.0 1157213.0 10164.0 2341760.0 20463.0 0.020085
1339 rows × 10 columns
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(14, 20))
# One panel per metric; within a panel, one line per 'location' group.
panel_for_metric = {
    'new_cases': axes[0, 0],
    'new_deaths': axes[0, 1],
    'total_cases': axes[1, 0],
    'total_deaths': axes[1, 1],
}
by_country = CVD_country.groupby('location')
for metric, panel in panel_for_metric.items():
    by_country[metric].plot(ax=panel, legend=True)
location
Italy AxesSubplot(0.547727,0.125;0.352273x0.343182)
South Korea AxesSubplot(0.547727,0.125;0.352273x0.343182)
Spain AxesSubplot(0.547727,0.125;0.352273x0.343182)
United States AxesSubplot(0.547727,0.125;0.352273x0.343182)
Name: total_deaths, dtype: object
# Title the 2x2 panels in row-major order: [0,0], [0,1], [1,0], [1,1].
panel_titles = ["New Cases", "New Deaths", "Total Cases", "Total Deaths"]
for panel, heading in zip(axes.flat, panel_titles):
    panel.set_title(heading)
Text(0.5, 1.0, 'Total Deaths')
# Number of missing (NaN) values per column of the full dataset.
# Fixed: the call was truncated to `rint(...)`, a NameError.
print(CVD.isnull().sum())
date 0
location 0
new_cases 333
new_deaths 333
total_cases 3303
total_deaths 12940
weekly_cases 1132
weekly_deaths 1132
biweekly_cases 2637
biweekly_deaths 2637
dtype: int64
# Relabel columns by name rather than by position, so a change in the CSV's
# column order cannot silently attach the wrong labels. errors='raise' still
# fails loudly if an expected source column is missing.
CVD = CVD.rename(
    columns={
        'location': 'Country',
        'new_cases': 'New Cases',
        'new_deaths': 'New deaths',
        'total_cases': 'Total Cases',
        'total_deaths': 'Total Deaths',
    },
    errors='raise',
)
CVD
date Country New Cases New deaths Total Cases Total Deaths weekly_cases weekly_deaths biweekly_cases biweekly_deaths
0 2019-12-31 Afghanistan 0.0 0.0 NaN NaN NaN NaN NaN NaN
1 2020-01-01 Afghanistan 0.0 0.0 NaN NaN NaN NaN NaN NaN
2 2020-01-02 Afghanistan 0.0 0.0 NaN NaN NaN NaN NaN NaN
3 2020-01-03 Afghanistan 0.0 0.0 NaN NaN NaN NaN NaN NaN
4 2020-01-04 Afghanistan 0.0 0.0 NaN NaN NaN NaN NaN NaN
... ... ... ... ... ... ... ... ... ... ...
59349 2020-11-25 Zimbabwe 90.0 1.0 9398.0 274.0 453.0 14.0 788.0 19.0
59350 2020-11-26 Zimbabwe 110.0 0.0 9508.0 274.0 527.0 13.0 841.0 19.0
59351 2020-11-27 Zimbabwe 115.0 0.0 9623.0 274.0 577.0 9.0 927.0 19.0
59352 2020-11-28 Zimbabwe 91.0 1.0 9714.0 275.0 594.0 10.0 949.0 18.0
59353 2020-11-29 Zimbabwe 108.0 0.0 9822.0 275.0 650.0 10.0 1036.0 18.0
59354 rows × 10 columns
# `~` negates the boolean mask: keep every row whose Country is NOT listed.
# Note the "World" aggregate row is dropped as well, despite the variable name.
excluded_locations = ["China", "World"]
keep_row = ~CVD['Country'].isin(excluded_locations)
CVD_no_china = CVD.loc[keep_row]
CVD_no_china
date Country New Cases New deaths Total Cases Total Deaths weekly_cases weekly_deaths biweekly_cases biweekly_deaths
0 2019-12-31 Afghanistan 0.0 0.0 NaN NaN NaN NaN NaN NaN
1 2020-01-01 Afghanistan 0.0 0.0 NaN NaN NaN NaN NaN NaN
2 2020-01-02 Afghanistan 0.0 0.0 NaN NaN NaN NaN NaN NaN
3 2020-01-03 Afghanistan 0.0 0.0 NaN NaN NaN NaN NaN NaN
4 2020-01-04 Afghanistan 0.0 0.0 NaN NaN NaN NaN NaN NaN
... ... ... ... ... ... ... ... ... ... ...
59349 2020-11-25 Zimbabwe 90.0 1.0 9398.0 274.0 453.0 14.0 788.0 19.0
59350 2020-11-26 Zimbabwe 110.0 0.0 9508.0 274.0 527.0 13.0 841.0 19.0
59351 2020-11-27 Zimbabwe 115.0 0.0 9623.0 274.0 577.0 9.0 927.0 19.0
59352 2020-11-28 Zimbabwe 91.0 1.0 9714.0 275.0 594.0 10.0 949.0 18.0
59353 2020-11-29 Zimbabwe 108.0 0.0 9822.0 275.0 650.0 10.0 1036.0 18.0
58684 rows × 10 columns
# One row per (Country, date) holding the two total columns. Columns are
# selected with a list: selecting ['a', 'b'] as a tuple on a GroupBy was
# deprecated (the FutureWarning in the output) and newer pandas rejects it.
# sum() also turns NaN totals into 0.0.
CVD_no_china = (
    CVD_no_china
    .groupby(['Country', 'date'])[['Total Cases', 'Total Deaths']]
    .sum()
    .reset_index()
)
<ipython-input-105-74dc02810ca5>:1: FutureWarning: Indexing with multiple keys (implicitly converted to a tuple of keys) will be deprecated, use a list instead.
CVD_no_china = pd.DataFrame(CVD_no_china.groupby(['Country', 'date'])['Total Cases', 'Total Deaths'].sum()).reset_index()
# Echo the regrouped frame (notebook display): one row per Country and date.
CVD_no_china
Country date Total Cases Total Deaths
0 Afghanistan 2019-12-31 0.0 0.0
1 Afghanistan 2020-01-01 0.0 0.0
2 Afghanistan 2020-01-02 0.0 0.0
3 Afghanistan 2020-01-03 0.0 0.0
4 Afghanistan 2020-01-04 0.0 0.0
... ... ... ... ...
58679 Zimbabwe 2020-11-25 9398.0 274.0
58680 Zimbabwe 2020-11-26 9508.0 274.0
58681 Zimbabwe 2020-11-27 9623.0 274.0
58682 Zimbabwe 2020-11-28 9714.0 275.0
58683 Zimbabwe 2020-11-29 9822.0 275.0
58684 rows × 4 columns
# Sort by Country and date, both descending, so each country's newest date comes
# first within its group (that row is what drop_duplicates(keep='first') keeps).
descending_order = [False, False]
CVD_no_china = CVD_no_china.sort_values(['Country', 'date'], ascending=descending_order)
CVD_no_china
Country date Total Cases Total Deaths
58683 Zimbabwe 2020-11-29 9822.0 275.0
58682 Zimbabwe 2020-11-28 9714.0 275.0
58681 Zimbabwe 2020-11-27 9623.0 274.0
58680 Zimbabwe 2020-11-26 9508.0 274.0
58679 Zimbabwe 2020-11-25 9398.0 274.0
... ... ... ... ...
4 Afghanistan 2020-01-04 0.0 0.0
3 Afghanistan 2020-01-03 0.0 0.0
2 Afghanistan 2020-01-02 0.0 0.0
1 Afghanistan 2020-01-01 0.0 0.0
0 Afghanistan 2019-12-31 0.0 0.0
58684 rows × 4 columns
# Plotting helper: top-10 bar chart.
def plot_bar(feature, value, title, df, size):
    """Show a bar chart of the 10 rows of ``df`` with the largest ``value``.

    Parameters
    ----------
    feature : str
        Column used as the bar labels on the x-axis (e.g. ``'Country'``).
    value : str
        Numeric column that ranks the rows (descending) and sets bar heights.
    title : str
        Inserted into the chart title "Number of {title} - highest 10 values".
    df : pandas.DataFrame
        Source data; a sorted copy is used, ``df`` itself is not modified.
    size : int or float
        Width multiplier: the figure is ``4 * size`` wide and 4 high.
    """
    fig, ax = plt.subplots(1, 1, figsize=(4 * size, 4))
    top = df.sort_values([value], ascending=False).reset_index(drop=True).head(10)
    # x/y must be passed by keyword: since seaborn 0.12 only ``data`` may be
    # positional, so the previous positional call raises a TypeError there.
    # Drawing on ``ax`` explicitly ties the bars to the figure created above.
    g = sns.barplot(data=top, x=feature, y=value, palette='Set2', ax=ax)
    g.set_title("Number of {} - highest 10 values".format(title))
    plt.show()
# First row per country -- given the descending sort above, its most recent date.
filtered_CVD_no_china = CVD_no_china.drop_duplicates(subset='Country', keep='first')
bar_charts = (
    ('Total Cases', 'Total cases in the World except China'),
    ('Total Deaths', 'Total deaths in the World except China'),
)
for metric, chart_title in bar_charts:
    plot_bar('Country', metric, chart_title, filtered_CVD_no_china, size=4)
filtered_CVD_no_china
Country date Total Cases Total Deaths
58683 Zimbabwe 2020-11-29 9822.0 275.0
58429 Zambia 2020-11-29 17589.0 357.0
58173 Yemen 2020-11-29 2160.0 615.0
57939 Western Sahara 2020-11-29 766.0 1.0
57721 Wallis and Futuna 2020-11-29 3.0 0.0
... ... ... ... ...
1455 Angola 2020-11-29 15087.0 345.0
1202 Andorra 2020-11-29 6670.0 76.0
935 Algeria 2020-11-29 81212.0 2393.0
600 Albania 2020-11-29 36790.0 787.0
334 Afghanistan 2020-11-29 45844.0 1763.0
213 rows × 4 columns
def plot_world_aggregate(df, title='Aggregate plot', size=1):
    """Line-plot ``'Total Cases'`` (blue) and ``'Total Deaths'`` (red) over ``'date'``.

    ``title`` is spliced into both the y-axis label and the chart title as
    "Total {title} cases"; the figure is ``4 * size`` wide and ``2 * size`` high.
    """
    fig, ax = plt.subplots(1, 1, figsize=(4 * size, 2 * size))
    for column, colour in (('Total Cases', 'blue'), ('Total Deaths', 'red')):
        sns.lineplot(x="date", y=column, data=df, color=colour, label=column, ax=ax)
    heading = f'Total {title} cases'
    ax.set_xlabel('date')
    ax.set_ylabel(heading)
    plt.xticks(rotation=90)
    ax.set_title(heading)
    ax.grid(color='black', linestyle='dotted', linewidth=0.75)
    plt.show()
# Per-date totals summed over all remaining countries (China and "World" excluded).
# Only the two numeric columns are summed: under pandas >= 2.0 a bare .sum()
# would also aggregate the string 'Country' column.
# NOTE(review): the summed totals fall over the last dates in the output
# (2020-11-27 -> 2020-11-29), probably because some countries have no row, or a
# NaN-turned-0 total, for the newest dates -- confirm before trusting the tail.
CVD_no_china_aggregate = (
    CVD_no_china
    .groupby(['date'])[['Total Cases', 'Total Deaths']]
    .sum()
    .reset_index()
)
CVD_no_china_aggregate
date Total Cases Total Deaths
0 2019-12-31 0.0 0.0
1 2020-01-01 0.0 0.0
2 2020-01-02 0.0 0.0
3 2020-01-03 0.0 0.0
4 2020-01-04 0.0 0.0
... ... ... ...
330 2020-11-25 59810576.0 1406042.0
331 2020-11-26 60460221.0 1418625.0
332 2020-11-27 61010116.0 1429263.0
333 2020-11-28 60612242.0 1412281.0
334 2020-11-29 59732850.0 1380975.0
335 rows × 3 columns
# Aggregate cases/deaths curve for every location except China and "World".
plot_world_aggregate(df=CVD_no_china_aggregate, title='Rest of the World except China', size=4)
def plot_aggregate_countries(df, countries, case_type='Total Cases', size=3, is_log=False,
                             start_date='2020-02-15'):
    """Plot ``case_type`` over time as one line per country in ``countries``.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format data with ``'Country'``, ``'date'`` and ``case_type`` columns.
    countries : iterable of str
        Country names to draw. Names with no rows after ``start_date`` are
        skipped instead of raising.
    case_type : str
        Column on the y-axis; also used as the y-label and the chart title.
    size : int or float
        Figure size multiplier (``4 * size`` wide, ``3 * size`` high).
    is_log : bool
        Use a logarithmic y-axis when true.
    start_date : str or datetime-like
        Only rows with ``date`` strictly after this are plotted. The default
        is the previously hard-coded ``'2020-02-15'``.
    """
    fig, ax = plt.subplots(1, 1, figsize=(4 * size, 3 * size))
    for country in countries:
        df_ = df[(df['Country'] == country) & (df['date'] > start_date)]
        if df_.empty:
            # max() of an empty selection raises; skip unknown/misspelled names.
            continue
        sns.lineplot(x="date", y=case_type, data=df_, label=country, ax=ax)
        # Put the country name at (latest date, largest value); Series.max()
        # skips NaN, unlike the builtin max().
        ax.text(df_['date'].max(), df_[case_type].max(), str(country))
    plt.xlabel('date')
    plt.ylabel(f' {case_type} ')
    plt.title(f' {case_type} ')
    plt.xticks(rotation=90)
    if is_log:
        ax.set(yscale="log")
    ax.grid(color='black', linestyle='dotted', linewidth=0.75)
    plt.show()
# Re-aggregate per (Country, date); this also re-sorts ascending and resets the index.
CVD_country_aggregate = (
    CVD_no_china
    .groupby(['Country', 'date'])
    .sum()
    .reset_index()
)
countries = [
    "United States", "Italy", "Spain", "South Korea",
    "France", "Germany", "Switzerland", "India",
]
for metric in ('Total Cases', 'Total Deaths'):
    plot_aggregate_countries(CVD_country_aggregate, countries, case_type=metric, size=4)
# Same total-cases chart, with a logarithmic y-axis.
plot_aggregate_countries(CVD_country_aggregate, countries, case_type='Total Cases', size=4, is_log=True)
def plot_mortality(df, title='Mainland China', size=1):
    """Line-plot ``df['Mortality (Deaths/Cases)']`` against ``df['date']``.

    The y-axis is labelled in percent ("[%]"), so the caller is expected to
    store the ratio already multiplied by 100. ``title`` is spliced into the
    y-label and chart title; the figure is ``4 * size`` by ``2 * size``.
    """
    width, height = 4 * size, 2 * size
    fig, ax = plt.subplots(1, 1, figsize=(width, height))
    sns.lineplot(
        x="date",
        y='Mortality (Deaths/Cases)',
        data=df,
        color='blue',
        label='Mortality (Deaths / Total Cases)',
        ax=ax,
    )
    ax.set_xlabel('date')
    ax.set_ylabel(f'Mortality {title} [%]')
    plt.xticks(rotation=90)
    ax.set_title(f'Mortality percent {title}\nCalculated as Deaths/Confirmed cases')
    ax.grid(color='black', linestyle='dashed', linewidth=1)
    plt.show()
# Mortality in percent: total deaths / total cases * 100 (dates with 0/0 give NaN).
deaths = CVD_no_china_aggregate['Total Deaths']
cases = CVD_no_china_aggregate['Total Cases']
CVD_no_china_aggregate['Mortality (Deaths/Cases)'] = deaths / cases * 100
plot_mortality(CVD_no_china_aggregate, title=' - Rest of the World except China', size=3)