본문 바로가기

AI/AI 라이브러리

[Tensorflow] 모델 컴파일, 학습, 예측

728x90
반응형

저번 포스트에서는 Tensorflow를 사용해 모델을 만드는 방법을 알아보았다. 

이제 모델을 만들었으니 모델을 사용하기 위해 컴파일, 학습, 예측의 순서대로 알아보자.

 

Complie

Tensorflow에서 모델 complie은 "모델 인스턴스.comple()"을 사용한다. 이 때 여러가지 속성을 적용할 수 있다.

 

주요 속성은 다음과 같다.

1. optimizer

  • 최적화기
  • 모델의 가중치를 업데이트하는 최적화 알고리즘을 지정
  • adam, sgd, rmsprop 등

 

2. loss 

  • 손실함수
  • 모델의 학습 중에 최소화될 손실 함수를 지정
  • categorical_crossentropy, binary_crossentropy, mse 등

 

3. metrics

  • 평가지표
  • 모델의 성능을 평가하기 위한 지표를 지정
  • accuracy, precision, recall 등
# 모델 컴파일
model.compile(optimizer='sgd',loss='mse')

 

이렇게 compile을 하면 이제 학습(fitting)을 할 수 있는 상태가 된다.

 

 

Fitting

모델 학습을 해보자. 

complie 된 모델에 X, Y, epochs, (선택 : batch_size, validation_data, callbacks)를 넣어 모델을 학습 한다.

 

주요 속성은 다음과 같다

1.X

  • 입력 데이터 셋
  • 훈련 데이터 셋을 넣음.

 

2. Y

  • target 데이터 셋(지도 학습의 경우)
  • 훈련 데이터 셋을 넣음

 

3. epochs

  • 훈련의 반복의 수

 

4. validation_data

  • 검증 데이터 셋을 지정

 

5. batch_size

  • 한번 epoch에서 사용되는 데이터의 개수(설정 안하면 전체 데이터 사용, Mini-batch를 위한 설정)

 

6. callbacks

  • 콜백 함수 지정
  • 모델의 학습 중 특정 이벤트가 발생할 때 호출되는 함수
  • 모델의 학습을 모니터링하고 제어하는 데 사용
  • ModelCheckpoint -  주어진 에포크마다 모델의 가중치를 저장하는 콜백이다. 이를 사용하여 학습 중간에 모델의 성능이 가장 좋았던 시점의 가중치를 저장할 수 있다.
  • EarlyStopping - 설정된 조건에 따라 학습을 조기 종료하는 콜백이다. 예를 들어, 손실이 더 이상 개선되지 않을 때 학습을 중지할 수 있다.
  • TensorBoard - 학습 과정을 시각화하기 위한 콜백이다. TensorBoard는 학습 과정에서 발생한 메트릭 및 로그를 시각적으로 확인할 수 있는 도구다.
  • ReduceLROnPlateau - 검증 손실이 개선되지 않을 때 학습률을 동적으로 조정하는 콜백이다. 이를 사용하여 학습 속도를 조절하여 더 나은 성능을 얻을 수 있다.

 

 

fit 메서드는 return 값으로 history 객체를 반환하는데 이 객체의 history 속성을 호출하면 loss와 성능 점수를 확인할 수 있다.

history = model.fit(
    [input1,input2],target,
    validation_data = ([val1,val2],val_target),
    batch_size=16,
    epochs = 20,
    callbacks = [tensorboard]
)

hist = history.history
#epoch 별 loss
loss = hist['loss']

 

 

이제 모델을 가지고 예측을 해보자

 

 

Predict

모델을 학습하고 예측을 진행한다.

 

"모델 인스턴스.predict(X)"를 사용하여 예측을 진행한다. 

 

리턴 값으로 예측한 Y값을 준다.

import numpy as np
pre = model.predict([test1,test2])
print(pre.shape)
print(test_target.shape)

print("MSE:",np.mean(np.square(test_target - pre)))

 

 

 

 

complie&fit&predict.ipynb

Colaboratory notebook

colab.research.google.com

 

728x90
반응형