기본 콘텐츠로 건너뛰기

[GCP] Flask로 TF 2.0 MNIST 모델 서빙하기

[GCP] Flask로 TF 2.0 MNIST 모델 서빙하기

Google Cloud Platform

우선 TensorFlow 2.0을 설치하자. 머신에 직접 설치하거나 도커를 다운받아 사용, 혹은 구글 colab을 활용( https://www.tensorflow.org/install)하면 되는데, TensorFlow에서 권장하는대로 머신에 VirtualEnv를 활용해서 설치하자

( https://www.tensorflow.org/install/pip). 설치하는 김에 Flask도 같이 설치해보자. Compute Machine 하나를 생성(크게 부담 없는 예제라 g1 instance)하고, SSH를 연결하여 실행하면 된다.

$ sudo apt update $ sudo apt install python3-dev python3-pip $ sudo pip3 install -U virtualenv # 굳이 system-wide로 flask를 설치할 필요는 없지만 그렇게 했다. $ sudo pip3 install flask $ sudo pip3 install flask-restful # virtualenv 환경에서 tensorflow 2.0 설치 $ virtualenv --system-site-packages -p python3 ./venv $ source ./venv/bin/activate # sh, bash, ksh, or zsh (venv) $ pip install --upgrade pip (venv) $ pip install --upgrade tensorflow

모든 환경이 마련되었으니, 우선 MNIST 모델을 TF 2.0으로 Training하여 모델을 Save 해 두자(tf_mnist_train.py). 대략 99% 이상 정확도가 나온다!

import tensorflow as tf import numpy as np # 학습 데이터 load ((train_data, train_label), (eval_data, eval_label)) = tf.keras.datasets.mnist.load_data() # data를 정규화하여 28x28로 reshape train_data=train_data/np.float32(255) train_data=train_data.reshape(60000, 28, 28, 1) train_data.shape eval_data = eval_data/np.float32(255) eval_data = eval_data.reshape(10000, 28, 28, 1) eval_data.shape from tensorflow.keras import models # CNN으로 모델 생성 model =models.Sequential() model.add(tf.keras.layers.Conv2D(32, (5,5), padding='same', activation='relu', input_shape=(28,28,1))) model.add(tf.keras.layers.MaxPooling2D((2,2))) model.add(tf.keras.layers.Conv2D(64, (5,5), activation='relu')) model.add(tf.keras.layers.MaxPooling2D((2,2))) model.add(tf.keras.layers.Conv2D(64, (5,5), activation='relu')) model.add(tf.keras.layers.Flatten()) model.add(tf.keras.layers.Dense(64, activation='relu')) model.add(tf.keras.layers.Dense(10, activation='softmax')) model.summary() # graph를 생성하고 training model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy']) model.fit(train_data, train_label, epochs=5) test_loss, test_acc = model.evaluate(eval_data, eval_label, verbose=2) test_acc # save the model. TF 2.0에서는 experimental 대신 save_model만 하면됨 model_dir = "/tmp/tfkeras_mnist" tf.keras.experimental.export_saved_model(model, model_dir)

python으로 위 파일을 실행하면 모델이 지정된 곳에 저장된다. 우선 해당 모델을 불러서 제대로 예측하는지 확인해 보면, 99% 정확도니 예측은 비교적 정확해야 한다.

import tensorflow as tf import numpy as np # eval data 불러오고 ((train_data, train_label), (eval_data, eval_label)) = tf.keras.datasets.mnist.load_data() eval_data = eval_data/np.float32(255) eval_data = eval_data.reshape(10000, 28, 28, 1) # 저장한 모델 불러 온뒤 model_dir = "/tmp/tfkeras_mnist" new_model = tf.keras.experimental.load_from_saved_model(model_dir) new_model.summary() # 그래프를 형성하고, new_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy']) # 임의의 위치에 있는 MNIST 숫자를 하나 읽어서 예측 random_idx = np.random.choice(eval_data.shape[0]) test_data = eval_data[random_idx].reshape(1, 28, 28, 1) res = new_model.predict(test_data) # 제대로 학습되었는지 확인 print ("predict: {}, original: {}".format(np.argmax(res), eval_label[random_idx]))

대체로 결과는 아래와 같다.

(venv) $ python tf_test_mnist.py predict: 6, original: 6

모델은 제대로 만들어졌으니, 이제 Flask를 이용해 웹으로 서빙해 보자. virtualenv를 user 계정으로 수행하고 있으므로 5000번 이상의 포트를 사용해야 한다. 이를 위해 GCP에서 방화벽 규칙을 하나 만들어 주자.

(네트워킹 → VPC 네트워크 → 방화벽 규칙)

Flask 실행은 python 명령으로 직접 py파일을 실행하거나, FLASK_APP으로 지정된 파일을 flask run으로 수행할 수 있다. flask에서 어떻게 tensorflow를 불러오는 지 몰라, 일단 python 명령으로 수행할 것이다. 이전에 test한 모듈을 http get으로 읽어오는 부분에 넣어주기만 하면 된다.

from flask import Flask, render_template import flask_restful import tensorflow as tf import numpy as np # 데이터를 읽어들이고 ((train_data, train_label), (eval_data, eval_label)) = tf.keras.datasets.mnist.load_data() eval_data = eval_data/np.float32(255) eval_data = eval_data.reshape(10000, 28, 28, 1) # 저장해 두었던 모델을 읽어들인 후 model_dir = "/tmp/tfkeras_mnist" new_model = tf.keras.experimental.load_from_saved_model(model_dir) new_model.summary() #그래프를 생성하고 new_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy']) # Flask Restful API로 읽어들일 APP을 지정. app = Flask(__name__) api = flask_restful.Api(app) # Flask가 사용할 리소스는 Test 클래스. # get 함수가 HTTP Get으로 결과를 읽어들임 class Test(flask_restful.Resource): def get(self): random_idx = np.random.choice(eval_data.shape[0]) random_idx test_data = eval_data[random_idx].reshape(1, 28, 28, 1) res = new_model.predict(test_data) return { 'predict': np.argmax(res).tolist(), 'answer': eval_label[random_idx].tolist() } # Test 클래스를 리소스로 추가. 두번째 인자는 파일의 위치. # 우리는 ~/venv/tf_mnist 현재 디렉토리에서 읽을 것이므로 '/' api.add_resource(Test, '/') # 사용하는 포트는 5000번 if __name__ == "__main__": app.run(host="0.0.0.0", port=5000)

모든 작업이 끝났으니 서빙을 시작해 보자.

(venv) ryu_gcloud2@flask-test:~/venv/tf_mnist$ python3 app.py * Serving Flask app "app" (lazy loading) * Environment: production WARNING: This is a development server. Do not use it in a production deployment. Use a production WSGI server instead. * Debug mode: off * Running on http://0.0.0.0:5000/ (Press CTRL+C to quit)

서버가 잘 수행되었으니 브라우저로 테스트 해 보면 된다. 외부에서 접근되는 것이니, GCP Compute Instance의 외부 IP 주소를 하나 할당(이번 예제의 경우엔 아래 35.223.49.131)받고 그 주소로 접근(포트는 5000번)하면 된다.

Flask로 학습 모델 결과를 제대로 서빙하고 있다.

from http://ml-cloud.tistory.com/4 by ccl(A)

댓글

이 블로그의 인기 게시물

스프링 프레임워크(Spring Framework)란?

스프링 프레임워크(Spring Framework)란? "코드로 배우느 스프링 웹 프로젝트"책을 개인 공부 후 자료를 남기기 위한 목적이기에 내용 상에 오류가 있을 수 있습니다. '스프링 프레임워크'가 무엇인지 말 할 수 있고, 해당 프레임워크의 특징 및 장단점을 설명할 수 잇는 것을 목표로합니다. 1. 프레임워크란? 2. 스프링 프레임워크 "뼈대나 근간을 이루는 코드들의 묶음" Spring(Java의 웹 프레임워크), Django(Python의 웹 프레임워크), Flask(Python의 마이크로 웹 프레임워크), Ruby on rails(Ruby의 웹 프레임워크), .NET Framework, Node.js(Express.js 프레임워크) 등등. 프레임워 워크 종류 : 3. 개발 시간을 단축할 수 있다. 2. 일정한 품질이 보장된 결과물을 얻을 수 있다. 1. 실력이 부족한 개발자라 허다러도 반쯤 완성한 상태에서 필요한 부분을 조립하는 형태의 개발이 가능하다. 프레임워크를 사용하면 크게 다음 3가지의 장점 이 있습니다. 프레임워크 이용 한다는 의미 : 프로그램의 기본 흐름이나 구조를 정하고, 모든 팀원이 이 구조에 자신의 코드를 추가하는 방식으로 개발 한다. => 이러한 상황을 극복하기 위한 코드의 결과물이 '프레임워크' 입니다. 개발자는 각 개개인의 능력차이가 크고, 따라서 개발자 구성에 따라서 프로젝트의 결과 차이가 큽니다. 2. 스프링 프레임워크(Spring Framework) 자바 플랫폼을 위한 오픈 소스 애플리케이션 스프링의 다른 프레임워크와 가장 큰 차이점은 다른 프레임워크들의 포용 입니다. 이는 다시말해 기본 뼈대를 흔들지 않고, 여러 종류의 프레임워크를 혼용해서 사용할 수 있다는 점입니다. 대한민국 공공기관의 웹 서비스 개발 시 사용을 권장하고 있는 전자정부 표준프레임워크 이다. 여러 프레임워크들 중 자바(JAV...

Dummy to resolve the flask problems

Dummy to resolve the flask problems This post is about flask problems that I struggled with. Hope you this is useful things when you taste it. Issue : How to deploy a flask application on Apache2 Resolve : As you know, flask is a micro framework. It can be handled on Apache2 using WSGI module. See the reference. Reference: https://www.digitalocean.com/community/tutorials/how-to-deploy-a-flask-application-on-an-ubuntu-vps Issue : Flask caused ERR_CONNECTION_ABORTED on POST Resolve : There are lots issues for this problem in principle. It caused when browser keep sending some buffer but server doesn't want to receive. My case is like this (submit.html) (submit.py) @bp.route('/submit', methods=["GET", "POST"]) def submit(): return render_template("submit.html") This kinda skel code to explain this. In flask case, this can be caused when it runs as develop server such as run...