[DL] Keras로 imagenet 모델 save(저장), load(불러오기) 본문

AI/Deep Learning

[DL] Keras로 imagenet 모델 save(저장), load(불러오기)

최재강 2021. 10. 1. 11:43

Keras에는 미리 학습된 모델(pretrained model)들을 제공해 준다. 따라서 제공해주는 모델들을 자신의 컴퓨터 환경에 저장하고 추후에 로드하여 사용할 수 있다. 이번 포스팅에서는 pretrained model을 local 환경에 save하고 load해오는 방법을 알아보자.

작업 환경

pip install tensorflow
pip install keras

필요한 라이브러리를 pip로 깔아주자.
나는 tensorflow 1.15.0, keras 2.1.5 버전에서 진행했다.

pretrained model list

keras에는 위와 같이 MobileNet처럼 Parameters수가 작은 모델 부터 VGG16같은 Heavy한 모델도 지원해주고 있다.

save model

save model을 위한 코드는 다음과 같다.

import tensorflow as tf
from tensorflow.keras.applications import mobilenet_v2

model_type = 'mobilenet_v2'

saved_model_dir = f'{model_type}_saved_model'

model = mobilenet_v2.MobileNetV2(weights='imagenet')
model.save(saved_model_dir)

tensorflow.keras.application에 mobilenet_v2를 포함한 위에서 언급한 표의 모델들을 호출하여 save할 수 있다. 위의 경우 mobilenet_v2만 impor하여 사용한 것이다.

model이라는 변수에 imagenet을 기반으로 pretrained된 mobilenet 모델을 넣어주고 save 함수를 호출하여 로컬에 저장했다. 이 때 경로의 경우 saved_model_dir 변수에서 설정해주었다.

load model

그럼 이제 load를 해보자.

from tensorflow.keras.models import load_model
model_type = 'mobilenet_v2'
model_path = f'{model_type}_saved_model'
model = load_model(model_path, compile=True)

model.summary()

load를 위해 keras의 load_model이라는 함수를 import 한다.

아까 save한 모델의 경로를 load_model 함수의 첫번째 인자로 넣어준다.

compile=True는 모델을 컴파일 해놓은 상태에서 load할 것인지 설정해주는 것이다 기본적으로 True가 기본 값이며 False로 설정 시 모델 구성을 수정할 수 있고 model.compile()을 통해 다시 컴파일 해야한다.

load한 후 모델을 summary, evaluate, predict 등 많은 작업을 진행해보자 !

여러 pretrained model 저장해보기

대학원에서 연구를 하면서 여러 모델을 사용해야 하는 일이 생겼다.
그래서 여러 모델을 import 해두고 model_type을 바꾸어 가며 필요한 모델을 주석해제하고 저장하였다. 필요하신 분들이 있을까 공유 드린다.

import tensorflow as tf
from tensorflow.keras.applications import ( 
    xception,
    vgg16,
    vgg19,
    resnet,
    resnet50,
    resnet_v2,
    inception_v3,
    inception_resnet_v2,
    mobilenet,
    densenet,
    nasnet,
    mobilenet_v2
)
model_type = 'xception'
saved_model_dir = f'{model_type}_saved_model'

models = {
    'xception':xception.Xception(weights='imagenet',include_top=False),
#     'vgg16':vgg16.VGG16(weights='imagenet'),
#     'vgg19':vgg19.VGG19(weights='imagenet'),
    'resnet50':resnet50.ResNet50(weights='imagenet'),
#     'resnet101':resnet.ResNet101(weights='imagenet'),
#     'resnet152':resnet.ResNet152(weights='imagenet'),
#     'resnet50_v2':resnet_v2.ResNet50V2(weights='imagenet'),
#     'resnet101_v2':resnet_v2.ResNet101V2(weights='imagenet'),
#     'resnet152_v2':resnet_v2.ResNet152V2(weights='imagenet'),
#     'resnext50':resnext.ResNeXt50(weights='imagenet'),
#     'resnext101':resnext.ResNeXt101(weights='imagenet'),
#     'inception_v3':inception_v3.InceptionV3(weights='imagenet'),
#     'inception_resnet_v2':inception_resnet_v2.InceptionResNetV2(weights='imagenet'),
#     'mobilenet':mobilenet.MobileNet(weights='imagenet'),
#     'densenet121':densenet.DenseNet121(weights='imagenet'),
#     'densenet169':densenet.DenseNet169(weights='imagenet'),
#     'densenet201':densenet.DenseNet201(weights='imagenet'),
#     'nasnetlarge':nasnet.NASNetLarge(weights='imagenet'),
#     'nasnetmobile':nasnet.NASNetMobile(weights='imagenet'),
#     'mobilenet_v2':mobilenet_v2.MobileNetV2(weights='imagenet')
}

model = models[model_type]
model.save(saved_model_dir)

 

 

참고

'AI > Deep Learning' 카테고리의 다른 글

Tensorflow MLFlow 사용해보기  (0) 2023.06.26
[DL] 딥러닝 추론이란?  (4) 2021.09.20
Comments