카테고리 없음

[딥러닝] 텐서플로우(TensorFlow) - 1. 데이터 전 처리_Epoch,Batch

오기오기 2021. 10. 26. 00:36
728x90
반응형

단층 퍼셉트론이 여러개 모여 히든층이 많아졌을때 우리가 흔히 들어본 딥러닝이라고 합니다

https://1percent-a-day.tistory.com/13

 

딥러닝 퍼셉트론(perceptron)_초기인공신경망

딥러닝은 머신러닝 방법론 중 하나로 인공신경망에 기반하여 컴퓨터에게 학습하는 방법입니다. 여기서 말하는 인공신경망이란 인간의 신경 시스템을 모방하여 만들어진 학습알고리즘입니다

1percent-a-day.tistory.com

 

텐서플로우

가장 대표적인 딥러닝 프레임 워크인 텐서플로우는 대형 클러스터 컴퓨터부터 스마트폰까지 다양한 디바이스에서 동작이 가능합니다 

딥러닝 모델을 구현하기 위한 첫번째단계는 데이터 전 처리하기 입니다

 

Tensorflow 딥러닝 모델은 Tensor 형태의 데이터를 입력받게되는데, 여기서 Tensor는 다차원 배열로서 tensorflow에서 사용하는 객체라고 생각하면 됩니다

 

딥러닝에서 사용하는 모델은 추가적인 전 처리 과정이 필요하기에 Epoch와 Batch를 사용합니다

  • Epoch 전체의데이터 셋에 대해 한번 학습을 완료한 상태 
  • Batch 나눠진 데이터 셋 (딥러닝이 학습과정에서 가중치(W)들을 업테이트할때 계산량을 줄이기 위해 Batch사용)

Data set을 Batch size로 나눈 값이 1 Epoch이며 몇번 iteration했는지 알 수 있게 됩니다 

 

 

텐서플로우 신경망 모델의 학습 데이터는 기존 데이터를 tf.data.Dataset 형식으로 변환할 때 from_tensor_slices() 메서드를 사용합니다. 

 ds = tf.data.Dataset.from_tensor_slices((X.values, Y.values))

 

변환된 Dataset인 ds에서 batch를 적용하고 싶다면 아래와 같이 batach() 메서드를 사용합니다.

ds = ds.shuffle(len(X)).batch(batch_size=5)

 

take() 메서드를 사용하면 batch로 분리된 데이터를 확인

 

 

마케터로서 광고 비용에 따른 수익률을 신경망을 통해서 예측해봅시다

 

import tensorflow as tf
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split

np.random.seed(100)
tf.random.set_seed(100)

# 데이터를 DataFrame 형태로 불러 옵니다.
df = pd.read_csv("data/Advertising.csv")

# DataFrame 데이터 샘플 5개를 출력합니다.
print('원본 데이터 샘플 :')
print(df.head(),'\n')

# 의미없는 변수는 삭제합니다.
df = df.drop(columns=['Unnamed: 0'])

"""
1. Sales 변수는 label 데이터로 Y에 저장하고 나머진 X에 저장합니다.
"""
X = df.drop(columns=['Sales'])
Y = df['Sales']

train_X, test_X, train_Y, test_Y = train_test_split(X, Y, test_size=0.3)

"""
2. 학습용 데이터를 tf.data.Dataset 형태로 변환합니다.
   from_tensor_slices 함수를 사용하여 변환하고 batch를 수행하게 합니다.
"""
train_ds = tf.data.Dataset.from_tensor_slices((train_X.values, train_Y.values))
train_ds = train_ds.shuffle(len(train_X)).batch(batch_size=5)

# 하나의 batch를 뽑아서 feature와 label로 분리합니다.
[(train_features_batch, label_batch)] = train_ds.take(1)

# batch 데이터를 출력합니다.
print('\nFB, TV, Newspaper batch 데이터:\n',train_features_batch)
print('Sales batch 데이터:',label_batch)

 

출처: 2021 nipa ai 온라인 교육

728x90
반응형