기본 콘텐츠로 건너뛰기

pytorch와 flask를 활용한 딥러닝 모델 서빙하기

pytorch와 flask를 활용한 딥러닝 모델 서빙하기

tensorflow 2.0을 활용해서 어떻게 서빙하는지 다뤄봤었는데, 요즘엔 pytorch를 사용하시는 분들도 많으니까 이번엔 pytorch를 서빙하는 방법에 대해서 설명드리려고 합니다!

이전 글들과 똑같이 mnist를 준비했고, 학습은 pytorch 공식 예제 참조하여 학습을 수행하였습니다.

아래 링크 참조하셔서 학습 진행해보시길 추천드려요!

https://github.com/pytorch/examples/tree/master/mnist

이전 포스팅에서 tensorflow 예제를 다룰 때는 pixel을 255로 나누어줬었는데, pytorch 예제를 보시면 0.1307과 0.3081이란 숫자를 활용해서 정규화를 해주는 것을 보실 수 있습니다.

이 부분은 mnist 데이터에서 전체 평균과 표준편차를 구하여 그 값을 활용하여 정규화를 수행하도록 한 것입니다.

pixel을 단순히 255로 나누는 것보다 평균과 표준편차를 활용하여 -1 ~ 1 사이의 값으로 정규화 해주는 것이 더 좋다고 하네요.

서빙이랑은 상관이 없으니 이쯤에서 넘어가도록 하겠습니다.

이제 pytorch 모델을 서빙하는 소스코드를 보도록 하겠습니다.

# flask_server.py import torch import numpy as np from torchvision import transforms from flask import Flask, jsonify, request from model import CNN model = CNN() model.load_state_dict(torch.load('mnist_model.pt'), strict=False) model.eval() normalize = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) app = Flask(__name__) @app.route('/inference', methods=['POST']) def inference(): data = request.json _, result = model.forward(normalize(np.array(data['images'], dtype=np.uint8)).unsqueeze(0)).max(1) return str(result.item()) if __name__ == '__main__': app.run(host='0.0.0.0', port=2431, threaded=False)

매우 간단하죠?

모델을 불러오고 정규화를 미리 정의해두고 그 후에는 요청이 들어올 때 마다 결과를 출력해서 반환하도록 구현하였습니다.

서버에 요청하는 소스코드는 아래와 같습니다.

# flask_test.py import json import requests import numpy as np from PIL import Image image = Image.open('test_image.jpg') pixels = np.array(image) headers = {'Content-Type':'application/json'} address = "http://127.0.0.1:2431/inference" data = {'images':pixels.tolist()} result = requests.post(address, data=json.dumps(data), headers=headers) print(str(result.content, encoding='utf-8'))

이미지를 불러와서 픽셀을 담아서 보내주면 끝!

이렇게 pytorch 모델까지 어떻게 서빙을 할 수 있는지 간단하게 살펴보았습니다.

하지만, 이렇게만 서빙한다고 서비스를 할 수 있는 건 절대 아니겠죠?

제일 먼저 병렬처리에 대해서 궁금해 하실 것 같네요.

flask_server.py에서 맨 밑에 줄에 threaded=True 옵션을 주면, 각 요청들이 각각의 쓰레드로 동작하면서 병렬처리가 가능하도록 flask에서 제공은 하고 있지만 pytorch에서는 그 기능을 사용할 경우에는 내부에서 데이터가 꼬이는 현상이 발생하게 됩니다. (Tensorflow에서는 문제가 없어서 threaded=True 옵션으로도 병렬처리가 가능은 합니다.)

그래서 쓰레드 방식보다는 프로세스를 여러개 띄우는 방식을 사용해야만 해요!

그런 부분에 대해서는 다음 포스팅에서 쓰레드와 프로세스의 차이, 그리고 파이썬에서의 병렬처리에 대해서도 간단하게 정리하고 어떤 방법을 활용해서 병렬처리를 해주면 좋을지 다음 포스팅에서 정리해보도록 하겠습니다.

이번 포스팅에서 활용된 전체 소스코드는 아래의 깃헙 주소로 가시면 모두 보실 수 있습니다!

오늘도 즐거운 딥러닝하세요!

https://github.com/hsh2438/mnist_serving_pytorch_flask.git

from http://seokhyun2.tistory.com/43 by ccl(A) rewrite - 2020-03-06 03:20:18

댓글

이 블로그의 인기 게시물

Coupang CS Systems 채용 정보: 쿠팡 운용 관리 시스템을 구축 하고...

Coupang CS Systems 채용 정보: 쿠팡 운용 관리 시스템을 구축 하고... Global Operation Technology는 상품을 고객에게 지연 없이 전달 될 수 있도록 하는 조직입니다. 1997년, 초창기 아마존에 입사한다고 상상해보세요. 그 당시 누구도 e-commerce 산업이, 아마존이라는 회사가 지금처럼 성장하리라고는 생각하지 못했을 것입니다. 하지만, 그 당시 아마존을 선택한 사람들은 e-commerce 산업을 개척했고, 아마존을 세계적인 회사로 성장시켰습니다. 2016년 '아시아의 아마존'으로 성장하고 있는 쿠팡, 당신에게 매력적인 선택이 아닐까요? Global Operation Technology: eCommerce에서 주문을 한 뒤 벌어지는 상황에 대해서 호기심을 가져보신 적이 있나요? Global Operation Technology는 상품을 고객에게 지연 없이 전달 될 수 있도록 하는 조직입니다. 매일 최첨단 소프트웨어 기술을 이용해 고객의 주문을 받고 상품을 어느 창고에서 출고 시킬지, 포장을 하나의 박스 또는 여러 개로 나눌 것인지, 어떤 배송 루트를 선택하고 어떻게 고객에게 배송 상태를 보여줄지 결정하는 시스템과 서비스를 개발 합니다. What Global Operations Technology does: CS and C-Returns System 적극적 고객서비스를 바탕으로 고객의 목소리를 통해 끊임없이 고객 에게 서비스를 제공하고 Andon 메커니즘을 통해 고객의 목소리를 회사 전체와 공유합니다. 그리고 고객 문제 해결과 구매 이후 벌어질 수 있는 고객 문제를 사전에 예방하기 위한 시스템 개발을 통해 미래의 상황을 예측 합니다. Tranportation System TSP (Traveling Salesman Problem) 와 같은 CS 최적화 관리 문제를 다룹니다.배송 물품의 실시간 추적, 3P 하드웨어와 소프트웨어를 통합, 각 배송 루트에 할당되는 물량 예측하고 T...

[ubuntu] FLASK_APP

[ubuntu] FLASK_APP Development/Debugging 🐞 FLASK_ENV=development FLASK_APP = app.py flask run zsh: command not found: FLASK_APP ✔️ FLASK_ENV=development FLASK_APP=app.py flask run 띄어쓰기를 해서 저런 오류를 출력할수도 있구나 😲 참고 : 108p에서 FLASK가 FKAS로 오타나있다. from http://hee-stories.tistory.com/18 by ccl(A) rewrite - 2020-03-24 17:20:11

[GCP] Argo로 Workflow 만들기

[GCP] Argo로 Workflow 만들기 사실 Production 레벨로 가지 않으면, ML개발에 Workflow를 사용할 일은 많지 않다. 대부분 샘플데이터로 전처리 한후 그 데이터를 공유해서 각자 모델을 개발하게 되는데, Production Level에서는 계속 새로운 데이터가 발생하기 때문에 데이터 수집부터 배포까지 하나의 파이프라인으로 관리해야할 필요성이 생긴다. Argo는 컨테이너 기반으르 파이프라인을 구성해주는 도구로 Kubeflow에서도 Workflow Orchestration은 Argo를 사용한다. Kubeflow Pipeline Overview Argo 설치 curl -sSL -o /usr/local/bin/argo https://github.com/argoproj/argo/releases/download/v2.2.1/argo-linux-amd64 chmod +x /usr/local/bin/argo Argo를 위와 같이 다운로드 받고, Controller와 UI를 kubectl을 통해 설치한다. GCP에서 kubectl의 설치는 아래를 따르면 된다. 터미널에서 Kubectl 사용하기 kubectl create ns argo kubectl apply -n argo -f https://raw.githubusercontent.com/argoproj/argo/v2.2.1/manifests/install.yaml Argo를 통해 간단한 'Hello World'예제를 실행해보자. 사용법은 아래와 같이 간단하다. submit은 지정된 yaml 파일을 workflow 만드는데 사용한다는 것이고 watch 파라미터는 외부의 yaml을 가져올 때 사용한다. argo submit --watch https://raw.githubusercontent.com/argoproj/argo/master/examples/hello-world.yaml 'argo list' 명령으로 실행되고 있는 argo wor...