728x90
반응형
Tensorflow에서 모델을 만드는 만드는 방법은 크게 3가지가 존재한다.
1. Sequential API
- Layer를 선형으로 쌓아 구성하는 방법.
- 간단한 모델을 만들 때 유용.
- 단일 입력, 히든, 출력 계층만 가능함.
- Layer sharing을 할 수 없다.
- 비선형 토폴로지를 사용할 수 없다.
- LeNet, AlexNet, VGGNet
2. Functional API
- 권장되는 방법.
- 다중 입력, 히든, 출력 계층 가능.
- Layer sharing 가능.
- 비선형 토폴리지 사용 가능.
- ResNet, GoogLeNet/Inception, Xception, SqueezeNet
3. Subclassing
- 모델, 계층 및 교육 프로세스를 완벽하게 제어해야 하는 고급 개발자를 위한 방법.
- 모델을 정의하는 사용자 정의 클래스를 생성하여 사용.
- 가장 유연한 방법이나 비용이 많이(활용 및 이해 어렵고 디버깅이 어려움, HDF5 형식으로 저장할 수 없음) 든다.
Sequential API
tf.keras.models.Sequential을 사용하여 레이어를 쌓는다.
import tensorflow as tf
from tensorflow.keras.utils import plot_model
input_shape = (10,10)
output_shape = (1,)
seq_model = tf.keras.models.Sequential([
# 여기에 모델을 쌓는다.
tf.keras.layers.Dense(64,activation='relu',input_shape=input_shape),
tf.keras.layers.Dense(128,activation='relu'),
tf.keras.layers.Dense(128,activation='relu'),
tf.keras.layers.Dense(128,activation='relu'),
tf.keras.layers.Dense(output_shape[0],activation='softmax'),
])
plot_model(seq_model, show_shapes=True) # 모델 시각화
Functional API
tf.keras.layers의 Layer들을 사용하여 레이어를 쌓는다.
inputs = tf.keras.Input(shape=input_shape)
x = tf.keras.layers.Dense(64,activation='relu')(inputs)
x = tf.keras.layers.Dense(128,activation='relu')(x)
x = tf.keras.layers.Dense(128,activation='relu')(x)
x = tf.keras.layers.Dense(128,activation='relu')(x)
outputs = tf.keras.layers.Dense(1,activation='relu')(x)
func_model1 = tf.keras.Model(inputs=inputs,outputs=outputs)
plot_model(func_model1, show_shapes=True)
비선형 토폴로지 구성(다중 입력, 출력, 히든 Layer) tf.keras.layers.Concatenate()로 합쳐주기
# Multi
input1= tf.keras.Input(shape=input_shape)
x1 = tf.keras.layers.Dense(64,activation='relu')(input1)
x1 = tf.keras.layers.Dense(128,activation='relu')(x1)
x1 = tf.keras.layers.Dense(128,activation='relu')(x1)
x1 = tf.keras.layers.Dense(128,activation='relu')(x1)
input2= tf.keras.Input(shape=(10,1))
x2 = tf.keras.layers.Dense(64,activation='relu')(input2)
x2 = tf.keras.layers.Dense(128,activation='relu')(x2)
# Concat, shape의 [0]이 같아야함
concat = tf.keras.layers.Concatenate()([x1,x2])
concat = tf.keras.layers.Dense(128,activation='relu')(concat)
# output
output1 = tf.keras.layers.Dense(128,activation='relu')(concat)
output2 = tf.keras.layers.Dense(128,activation='relu')(concat)
multi_model = tf.keras.Model(inputs=[input1,input2],outputs=[output1,output2])
plot_model(multi_model, show_shapes=True)
Subclassing
Python Class를 정의해 모델을 정의한다.
사용되는 메서드
1. __init__()
- 클래스 인스턴스를 초기화
- 메서드를 정의하거나 추가적인 속성이나 초기화 작업 수행
- Layer를 정의한다.
2. super()
- 부모 클래스를 호출(tf.keras.Model)
- tf.keras.Model에 정의된 메서드를 사용하기 위해 호출
3. call()
- tf.keras.Model에 정의된 메서드
- 사용자 정의 모델 클래스에서 오버라이딩하여 모델의 호출 동작을 구현
- 입력 데이터를 받아 Forward Propation을 수행하고 출력을 반환
4. build()
- tf.keras.Model에 정의된 메서드
- weight를 동적으로 생성하는 역할
class MyModel(tf.keras.Model):
def __init__(self,input_shape):
super(MyModel,self).__init__()
self.shape = input_shape
self.d1 = tf.keras.layers.Dense(64, activation='relu',name='dense1')
self.d2 = tf.keras.layers.Dense(128, activation='relu',name='dense2')
self.concat = tf.keras.layers.Concatenate(name='concat1')
self.d3 = tf.keras.layers.Dense(128, activation='relu',name='dense3')
self.d4 = tf.keras.layers.Dense(128, activation='relu',name='dense4')
self.d5 = tf.keras.layers.Dense(128, activation='relu',name='dense5')
self.out = tf.keras.layers.Dense(1,activation='softmax',name='output')
def call(self,inputs:tuple):
input1, input2 = inputs
x1 = self.d1(input1)
x1 = self.d2(x1)
x1 = self.d3(x1)
x2 = self.d4(input2)
x2 = self.d5(input2)
concat = self.concat([x1,x2])
return self.out(concat)
def build_graph(self): # 시각화를 위한 코드
x = tf.keras.Input(shape=(self.shape))
return tf.keras.Model(inputs=[x], outputs=self.call((x,x)))
mmd = MyModel(input_shape)
mmd.build(input_shape=[input_shape,input_shape]) # subclassing은 build를 해줘야함
plot_model(mmd.build_graph(), show_shapes=True)
Create Model.ipynb
Colaboratory notebook
colab.research.google.com
728x90
반응형
'AI > AI 라이브러리' 카테고리의 다른 글
[Tensorflow] Callback 알아보기 (0) | 2024.01.10 |
---|---|
[Tensorflow] 모델 컴파일, 학습, 예측 (0) | 2024.01.09 |
[Tensorflow] Tensorflow 입문 (0) | 2024.01.08 |
[Pytorch] 연습 데이터로 라이브러리 익숙해지기2 (0) | 2024.01.03 |
[Pytorch] 연습 데이터로 라이브러리 익숙해지기1 (0) | 2024.01.02 |