본문 바로가기
데이터 시각화

Python 데이터 시각화 - 상관관계(correlation) 분석, Heatmap (with Matplotlib, seaborn )

by 맑은안개 2021. 3. 30.

출처: Pixabay

데이터 상관관계 시각화

Pandas dataframe 구조의 데이터를 사용하여 데이터 간의 상관관계(correlation)를 확인해보고 이를 Heatmap으로 시각화하여 표현해보자. 

 

샘플데이터는 Plotly에서 제공하는 샘플 데이터를 사용하도록 한다.


상관관계 분석(Correlation Analysis)

상관관계 분석은 두 개이상의 변수 사이에 존재하는 상호 연관성 존재와 그 강도를 측정하는 방법이다.

개발환경

- Pandas 3.9.0

- Jupyter-lab 3.0

- Plotly 4.14.3

- matplotlib 3.2.2

- seaborn 0.11.1

임포트 라이브러리

# Data
import numpy as np
import pandas as pd

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go

#warning
import warnings
warnings.filterwarnings('ignore')

샘플데이터 로드 

stocks = px.data.stocks()
stocks
	date	GOOG	AAPL	AMZN	FB	NFLX	MSFT
0	2018-01-01	1.000000	1.000000	1.000000	1.000000	1.000000	1.000000
1	2018-01-08	1.018172	1.011943	1.061881	0.959968	1.053526	1.015988
2	2018-01-15	1.032008	1.019771	1.053240	0.970243	1.049860	1.020524
3	2018-01-22	1.066783	0.980057	1.140676	1.016858	1.307681	1.066561
4	2018-01-29	1.008773	0.917143	1.163374	1.018357	1.273537	1.040708
...	...	...	...	...	...	...	...
100	2019-12-02	1.216280	1.546914	1.425061	1.075997	1.463641	1.720717
101	2019-12-09	1.222821	1.572286	1.432660	1.038855	1.421496	1.752239
102	2019-12-16	1.224418	1.596800	1.453455	1.104094	1.604362	1.784896
103	2019-12-23	1.226504	1.656000	1.521226	1.113728	1.567170	1.802472
104	2019-12-30	1.213014	1.678000	1.503360	1.098475	1.540883	1.788185
105 rows × 7 columns

상관관계 분석

corr_df = stocks.corr()
corr_df = corr_df.apply(lambda x: round(x ,2))
corr_df
	GOOG	AAPL	AMZN	FB	NFLX	MSFT
GOOG	1.00	0.83	0.56	0.63	0.14	0.75
AAPL	0.83	1.00	0.56	0.49	0.05	0.79
AMZN	0.56	0.56	1.00	0.34	0.62	0.66
FB	0.63	0.49	0.34	1.00	0.27	0.47
NFLX	0.14	0.05	0.62	0.27	1.00	0.08
MSFT	0.75	0.79	0.66	0.47	0.08	1.00

- 1에 가까울 수록 관계가 깊으며(강한 양의 상관관계) 0에 가까울 수록 관계가 적다(약한 양의 상관관계). 

- 수치 상으로 볼 때 구글(GOOG)주식은 애플(AAPL)의 주식 흐름과 밀접한 관계에 있다고 할 수 있다. 넷플릭스(NFLX)는 구글주식과는 관계가 적으며 아마존(AMZN)과 비교적 높은 상관관계 계수를 갖는다 할 수 있다. 

- 아래 데이터 시각화를 통해 각 기업주식의 관계를 직관적이며 쉽게 이해할 수 있게 해보자. 

s = corr_df.unstack()
s
GOOG  GOOG    1.00
      AAPL    0.83
      AMZN    0.56
      FB      0.63
      NFLX    0.14
      MSFT    0.75
AAPL  GOOG    0.83
      AAPL    1.00
      AMZN    0.56
      FB      0.49
      NFLX    0.05
      MSFT    0.79
AMZN  GOOG    0.56
      AAPL    0.56
      AMZN    1.00
      FB      0.34
      NFLX    0.62
      MSFT    0.66
      ... 생략 ...

- unstack()을 활용하여 매트릭스가 아닌 Series형태로 표현할 수 있다. 

- DataFrame의 style프로퍼티를 사용하여 출력 데이터를 시각화할 수 있다. 

# Series이므로 DataFrame으로 변경한다. 
df = pd.DataFrame(s[s < 1].sort_values(ascending=False), columns=['corr'])
df.style.background_gradient(cmap='viridis')

- 관계계수가 높은 순(Descending)부터 정렬하여 Styling처리. ( 1 은 자신과의 관계이므로 제외 )

DataFrame Styling

Heatmap 데이터 시각화

Matplotlib

fig, ax = plt.subplots()
im = ax.imshow(corr_df, cmap='Greys')

# Color Bar
cbar = ax.figure.colorbar(im, ax=ax)

ax.set_xticks(np.arange(len(corr_df.columns)))
ax.set_yticks(np.arange(len(corr_df.index)))

ax.set_xticklabels(corr_df.columns)
ax.set_yticklabels(corr_df.columns)

for x in range(len(corr_df.columns)):
    for y in range(len(corr_df.index)):
        ax.text(y, x, corr_df.iloc[y, x], ha='center', va='center', color='g')

fig.tight_layout()   
plt.show()    

- imshow()을 사용하여 heatmap을 그린다.

- cmap 매개변수를 입력하여 heatmap에서 표현할 Color를 지정할 수 있다. ( Color 목록은 여기를 참조 )

- ax.text()를 사용하여 관계계수를 텍스트로 표현한다. 

heatmap by matplotlib

seaborn

- seaborn을 사용하여 더 간편하게 Heatmap을 그릴수 있다. 

ax = sns.heatmap(corr_df, annot=True, annot_kws=dict(color='g'), cmap='Greys')
plt.show()

heatmap by seaborn

 

특정 컬럼 기준, n번째 상/하위 컬럼 추출하기

- nlargest, nsmallest 를 사용하여 지정한 컬럼 기준 상위, 하위 데이터를 추출할 수 있다. 이를 사용하여 상관관계지수가 높거나 낮은 데이터를 추출한다.

- sort_values와 head를 조합한 방법 보다 더 나은 성능을 제공한다.

This method is equivalent to df.sort_values(columns, ascending=False).head(n), but more performant.
corr5 = corr_df.nlargest(5, 'GOOG')
corr5 = corr5[list(corr5.index)]
corr5
	GOOG	AAPL	MSFT	FB	AMZN
GOOG	1.00	0.83	0.75	0.63	0.56
AAPL	0.83	1.00	0.79	0.49	0.56
MSFT	0.75	0.79	1.00	0.47	0.66
FB	0.63	0.49	0.47	1.00	0.34
AMZN	0.56	0.56	0.66	0.34	1.00

 

corr3 = corr_df.nlargest(3, 'GOOG')
corr3 = corr3[list(corr3.index)]
corr3
	GOOG	AAPL	MSFT
GOOG	1.00	0.83	0.75
AAPL	0.83	1.00	0.79
MSFT	0.75	0.79	1.00
반응형