
저번 포스트에서는 Tensorflow를 사용해 모델을 만드는 방법을 알아보았다.
이제 모델을 만들었으니 모델을 사용하기 위해 컴파일, 학습, 예측의 순서대로 알아보자.
Complie
Tensorflow에서 모델 complie은 "모델 인스턴스.comple()"을 사용한다. 이 때 여러가지 속성을 적용할 수 있다.
주요 속성은 다음과 같다.
1. optimizer
- 최적화기
- 모델의 가중치를 업데이트하는 최적화 알고리즘을 지정
- adam, sgd, rmsprop 등
2. loss
- 손실함수
- 모델의 학습 중에 최소화될 손실 함수를 지정
- categorical_crossentropy, binary_crossentropy, mse 등
3. metrics
- 평가지표
- 모델의 성능을 평가하기 위한 지표를 지정
- accuracy, precision, recall 등
# 모델 컴파일
model.compile(optimizer='sgd',loss='mse')
이렇게 compile을 하면 이제 학습(fitting)을 할 수 있는 상태가 된다.
Fitting
모델 학습을 해보자.
complie 된 모델에 X, Y, epochs, (선택 : batch_size, validation_data, callbacks)를 넣어 모델을 학습 한다.
주요 속성은 다음과 같다
1.X
- 입력 데이터 셋
- 훈련 데이터 셋을 넣음.
2. Y
- target 데이터 셋(지도 학습의 경우)
- 훈련 데이터 셋을 넣음
3. epochs
- 훈련의 반복의 수
4. validation_data
- 검증 데이터 셋을 지정
5. batch_size
- 한번 epoch에서 사용되는 데이터의 개수(설정 안하면 전체 데이터 사용, Mini-batch를 위한 설정)
6. callbacks
- 콜백 함수 지정
- 모델의 학습 중 특정 이벤트가 발생할 때 호출되는 함수
- 모델의 학습을 모니터링하고 제어하는 데 사용
- ModelCheckpoint - 주어진 에포크마다 모델의 가중치를 저장하는 콜백이다. 이를 사용하여 학습 중간에 모델의 성능이 가장 좋았던 시점의 가중치를 저장할 수 있다.
- EarlyStopping - 설정된 조건에 따라 학습을 조기 종료하는 콜백이다. 예를 들어, 손실이 더 이상 개선되지 않을 때 학습을 중지할 수 있다.
- TensorBoard - 학습 과정을 시각화하기 위한 콜백이다. TensorBoard는 학습 과정에서 발생한 메트릭 및 로그를 시각적으로 확인할 수 있는 도구다.
- ReduceLROnPlateau - 검증 손실이 개선되지 않을 때 학습률을 동적으로 조정하는 콜백이다. 이를 사용하여 학습 속도를 조절하여 더 나은 성능을 얻을 수 있다.
fit 메서드는 return 값으로 history 객체를 반환하는데 이 객체의 history 속성을 호출하면 loss와 성능 점수를 확인할 수 있다.
history = model.fit(
[input1,input2],target,
validation_data = ([val1,val2],val_target),
batch_size=16,
epochs = 20,
callbacks = [tensorboard]
)

hist = history.history
#epoch 별 loss
loss = hist['loss']

이제 모델을 가지고 예측을 해보자
Predict
모델을 학습하고 예측을 진행한다.
"모델 인스턴스.predict(X)"를 사용하여 예측을 진행한다.
리턴 값으로 예측한 Y값을 준다.
import numpy as np
pre = model.predict([test1,test2])
print(pre.shape)
print(test_target.shape)
print("MSE:",np.mean(np.square(test_target - pre)))

complie&fit&predict.ipynb
Colaboratory notebook
colab.research.google.com
'AI > AI 라이브러리' 카테고리의 다른 글
[Tensorflow] Optimizer / Loss Function / Metric (0) | 2024.01.11 |
---|---|
[Tensorflow] Callback 알아보기 (0) | 2024.01.10 |
[Tensorflow] 모델 만들기 (0) | 2024.01.09 |
[Tensorflow] Tensorflow 입문 (0) | 2024.01.08 |
[Pytorch] 연습 데이터로 라이브러리 익숙해지기2 (0) | 2024.01.03 |