Pythonで決定木分類器をフルスクラッチで実装してみた
機械学習モデルをスクラッチから実装しようと思い立ったので、第一歩として決定木分類器(Decision Tree Classifer) をPythonで実装してみました。RandomForestやXGBoostなどといった決定木系の機械学習アルゴリズムを使う場面も多いと思うので、その基礎となる決定木分類器について実装をしながら仕組みを理解していきます。
決定木とは
決定木とは木構造を用いて分類または回帰のモデルを構築する手法です。変数を基準に条件分岐を行い、末端のノード(葉ノード)にたどりつくまで分岐をたどることで、分類クラスや回帰の結果を決定するものです。木が深くなるほど過学習が起きやすくなってしまうため、学習時には少ない分岐でより高い識別能力を持つ木を構築することが求められます。
分割の基準となる説明変数とその閾値が優先度の高いものから順に並ぶため、分類の解釈がしやすいといったメリットがあります。また、変数の重要度 (Feature importance) を算出することで、どの説明変数が識別に寄与しているかということを定量的に把握できるという特長もあります。
決定木の学習を実装
Pythonを使って決定木の学習をスクラッチから実装していきます。スクラッチとはいえnumpyを使用しているという点についてはご容赦ください。決定木アルゴリズムにはいくつかの種類がありますが、今回は2分木を前提とし、ノードの分割条件を求める評価基準としてGini係数を使うCARTアルゴリズムで実装を進めます。実装はPython3で、numpyおよびscikit-learnがインストールされていることを前提としています。手早く環境構築をしたい方はanacondaを使いましょう。
Gini係数による不純度の計算
決定木を構築する際の最も重要な指標として不純度 (impurity) というものがあります。これは、あるノードに分岐したサンプルのクラスがどれほど散らばっているかということを表します。分岐させた結果、あるノードに1種類のクラスだけが分類された場合(きれいに分割できた場合)、そのノードの不純度は0となります。このように、不純度を求める方法としては交差エントロピーとGini係数がありますが、今回はCARTアルゴリズムに従ってGini係数を使います。
ノードtにおけるクラスの割合をとすると、Gini係数は次の式を使って求めることができます。
よくわからないので具体例を使って説明をしていきます。
文系・理系のサンプルが各200人ずついるとし、数学または国語の点数でこれらを分類したいとします。
より識別性能が高いものを木の上のほうに配置したいため、分岐条件AとBの不純度をそれぞれ求めます。
分割条件Aの ’yes’ グループのGini係数は
【Gini(yes)】
分割条件Aの ’no’ グループのGini係数は
【Gini(no)】
分割ルールAによる不純度は次のよう二求めることができます。
同様に、分割条件Bの不純度を計算すると0.49となりました。
分割条件Aのほうが不純度が低いため、分割条件Aのほうが識別能力が高い分岐条件であることがわかります。Gini係数を計算することで、ノードを分割する最適な説明変数および閾値を求めていき、決定木を構築していきます。
Gini係数を求める関数は次のように実装しました。
def gini_score(data, target, feat_idx, threshold): gini = 0 sample_num = len(target) div_target = [target[data[:, feat_idx] >= threshold], target[data[:, feat_idx] < threshold]] for group in div_target: score = 0 classes = np.unique(group) for cls in classes: p = np.sum(group == cls)/len(group) score += p * p gini += (1- score) * (len(group)/sample_num) return gini
ノードを分割する最適な説明変数と閾値の選択
Gini係数が小さい分岐条件を見つけることで、識別能力の高い木が構築できるということがわかりました。今回は、最適な分岐条件をみつけるためにすべての説明変数について閾値を変化させていき、Gini係数が最小となる説明変数と閾値のペアをノードの分岐条件とする方針で実装していきます。
ノードに流れてくるdataおよびそれらに付与されたクラスラベル targetを渡すことで、変数giniが最小となるときの閾値best_thrsおよび説明変数best_fを返り値として得ることができます。
def search_best_split(data, target): features = data.shape[1] best_thrs = None best_f = None gini = None gini_min = 1 for feat_idx in range(features): values = data[:, feat_idx] for val in values: gini = gini_score(data, target, feat_idx, val) if gini_min > gini: gini_min = gini best_thrs = val best_f = feat_idx return gini_min, best_thrs, best_f
再帰関数を使った木の構築
木の構築方法はいくつか方法が考えられますが、ここではオブジェクトと再帰関数を使った実装を行います。まずDecisionTreeNodeクラスを定義し、splitメソッドを用いることでノードに対応するDecisionTreeNodeの子オブジェクトを再帰的に生成していきます。停止条件としては、以下の2通りを定義しています。
- ノードのGini係数が0になったとき(ノードに1つのクラスのサンプルしかないため、以降分岐の必要がない)
- 木の深さがあらかじめ定義したmax_depthに達したとき
各ノードはleftとrightというクラス変数を持ち、ここに子ノードのオブジェクトが格納されていくため、仮想的に木を構築することができます。
def split(self, depth): self.depth = depth self.gini_min, self.threshold, self.feature = search_best_split(self.data, self.target) print('Depth: {}, Sep at Feature: {},Threshold: {}, Label: {}'.format(self.depth, self.feature, self.threshold, self.label)) if self.depth == self.max_depth or self.gini_min == 0: return idx_left = self.data[:, self.feature] >= self.threshold idx_right = self.data[:, self.feature] < self.threshold self.left = DecisionTreeNode(self.data[idx_left], self.target[idx_left], self.max_depth) self.right = DecisionTreeNode(self.data[idx_right], self.target[idx_right], self.max_depth) self.left.split(self.depth +1) self.right.split(self.depth +1)
DecisionTreeNodeクラス全体は次のように実装しました。
class DecisionTreeNode(object): def __init__(self, data, target, max_depth): self.left = None self.right = None self.max_depth = max_depth self.depth = None self.data = data self.target = target self.threshold = None self.feature = None self.gini_min = None self.label = np.argmax(np.bincount(target)) def split(self, depth): self.depth = depth self.gini_min, self.threshold, self.feature = search_best_split(self.data, self.target) print('Depth: {}, Sep at Feature: {},Threshold: {}, Label: {}'.format(self.depth, self.feature, self.threshold, self.label)) if self.depth == self.max_depth or self.gini_min == 0: return idx_left = self.data[:, self.feature] >= self.threshold idx_right = self.data[:, self.feature] < self.threshold self.left = DecisionTreeNode(self.data[idx_left], self.target[idx_left], self.max_depth) self.right = DecisionTreeNode(self.data[idx_right], self.target[idx_right], self.max_depth) self.left.split(self.depth +1) self.right.split(self.depth +1) def predict(self, data): if self.gini_min == 0.0 or self.depth == self.max_depth: return self.label else: if data[self.feature] > self.threshold: return self.left.predict(data) else: return self.right.predict(data)
決定木のテストを実装
さきほど作成したDecisionTreeNodeクラスに、テスト時の予測メソッドを実装しました。各ノードのオブジェクトには、best_splitで求めた最適な説明変数と閾値が格納されているためこれを使用します。先ほどの停止条件と同様に、Gini係数が0または深さがmax_depthに達した場合は末端ノード(葉ノード)となるため、ここでラベルを返します。これはノードに含まれているサンプルのうち最も多いクラスを返しており、これが予測ラベルとなります。葉ノードに達していない場合は、閾値に従って分岐を進んでいき、末端に達した時点で予測ラベルを返すという実装になります。
def predict(self, data): if self.gini_min == 0.0 or self.depth == self.max_depth: return self.label else: if data[self.feature] > self.threshold: return self.left.predict(data) else: return self.right.predict(data)
自作Pythonコード
全体のPythonコードは次の通りです。scikit-learnのメソッドにならって、fitとpredictを持つクラスDesicionTreeClassifierを定義しました。
import numpy as np from sklearn.datasets import load_iris from sklearn.model_selection import train_test_split def gini_score(data, target, feat_idx, threshold): gini = 0 sample_num = len(target) div_target = [target[data[:, feat_idx] >= threshold], target[data[:, feat_idx] < threshold]] for group in div_target: score = 0 classes = np.unique(group) for cls in classes: p = np.sum(group == cls)/len(group) score += p * p gini += (1- score) * (len(group)/sample_num) return gini def search_best_split(data, target): features = data.shape[1] best_thrs = None best_f = None gini = None gini_min = 1 for feat_idx in range(features): values = data[:, feat_idx] for val in values: gini = gini_score(data, target, feat_idx, val) if gini_min > gini: gini_min = gini best_thrs = val best_f = feat_idx return gini_min, best_thrs, best_f class DecisionTreeNode(object): def __init__(self, data, target, max_depth): self.left = None self.right = None self.max_depth = max_depth self.depth = None self.data = data self.target = target self.threshold = None self.feature = None self.gini_min = None self.label = np.argmax(np.bincount(target)) def split(self, depth): self.depth = depth self.gini_min, self.threshold, self.feature = search_best_split(self.data, self.target) print('Depth: {}, Sep at Feature: {},Threshold: {}, Label: {}'.format(self.depth, self.feature, self.threshold, self.label)) if self.depth == self.max_depth or self.gini_min == 0: return idx_left = self.data[:, self.feature] >= self.threshold idx_right = self.data[:, self.feature] < self.threshold self.left = DecisionTreeNode(self.data[idx_left], self.target[idx_left], self.max_depth) self.right = DecisionTreeNode(self.data[idx_right], self.target[idx_right], self.max_depth) self.left.split(self.depth +1) self.right.split(self.depth +1) def predict(self, data): if self.gini_min == 0.0 or self.depth == self.max_depth: return self.label else: if data[self.feature] > self.threshold: return self.left.predict(data) else: return self.right.predict(data) class DesicionTreeClassifier(object): def __init__(self, max_depth): self.max_depth = max_depth self.tree = None def fit(self, data, target): initial_depth = 0 self.tree = DecisionTreeNode(data, target, self.max_depth) self.tree.split(initial_depth) def predict(self, data): pred = [] for s in data: pred.append(self.tree.predict(s)) return np.array(pred) if __name__ == '__main__': iris = load_iris() data = iris.data target = iris.target X_train, X_test, y_train, y_test = train_test_split(data, target, test_size=0.3, random_state=0) clf = DesicionTreeClassifier(max_depth=3) clf.fit(X_train, y_train) y_pred = clf.predict(X_test) score = sum(y_pred == y_test)/float(len(y_test)) print('Classification accuracy: {}'.format(score))
自作コードを使った分類結果
フルスクラッチで書いたコードを使って、irisデータセットをもとに学習および評価を行いました。
学習した決定木は次のようになりました。可視化していないのでわかりにくいですが、深さ1の右側の木(一番下の行)はそれ以降分岐がないためGini係数が0であったことがわかります。このときの予測ラベルは0番のクラスです。また、逆側のブランチは深さ3まで伸びています。
Depth: 0, Sep at Feature: 2,Threshold: 3.0, Label: 2 Depth: 1, Sep at Feature: 2,Threshold: 5.0, Label: 2 Depth: 2, Sep at Feature: 2,Threshold: 5.1, Label: 2 Depth: 3, Sep at Feature: 0,Threshold: 6.5, Label: 2 Depth: 3, Sep at Feature: 0,Threshold: 6.7, Label: 2 Depth: 2, Sep at Feature: 3,Threshold: 1.7, Label: 1 Depth: 3, Sep at Feature: 1,Threshold: 3.2, Label: 2 Depth: 3, Sep at Feature: 0,Threshold: 5.0, Label: 1 Depth: 1, Sep at Feature: 0,Threshold: 4.8, Label: 0
最後に、正解率を求めることで分類器が機能していることを確認しました。
学習データ:テストデータ=7:3、max_depth=3としたときの識別精度は以下の通りでした。
Classification accuracy: 0.977
いい感じに分類できていることがわかりますね。
Pythonを使ってカメラ映像をプレビュー表示しながら動画として保存する
画像を扱う仕事をしていると、カメラを使って自ら画像を撮影しなければいけない場面がありますよね。今回はPCの内臓カメラやUSBカメラを使って、カメラ映像を動画として保存するコードを実装したのでご紹介します。
Pythonソースコード
引数でいろいろな設定をできるようにしています。
import argparse import cv2 import datetime import os def set_arguments(): parser = argparse.ArgumentParser() parser.add_argument('device_id', type=int, help="通常は、 0: 内蔵カメラ, 1: USBカメラ") parser.add_argument('limit', type=int, help="撮影フレーム数") parser.add_argument('-f', '--fps', default=30, help="撮影FPS") parser.add_argument('-o', '--output_directory', help="出力先ディレクトリ") parser.add_argument('-p', '--preview', type=int, default=1, help="プレビュー画面の表示有無 0:表示しない 1:表示する(デフォルト)") return parser.parse_args() def get_outputpath(output_directory): now = datetime.datetime.now() file_name = '{}.avi'.format(now.strftime('%Y%m%d_%H%M%S')) if output_directory is None: output_path = file_name else: os.makedirs(output_directory, exist_ok=True) output_path = os.path.join(output_directory, file_name) return output_path def set_camera(device_id, fps): fps = int(fps) camera = cv2.VideoCapture(device_id) camera.set(cv2.CAP_PROP_FRAME_WIDTH, 640) camera.set(cv2.CAP_PROP_FRAME_HEIGHT, 480) camera.set(cv2.CAP_PROP_FPS, fps) if camera is None: raise Exception("Camera not found. Check device id.") return camera def capture(camera, output_filepath, show_preview, limit = None): frame_number = 0 fourcc = cv2.VideoWriter_fourcc(*'MJPG') ret, frame = camera.read() fps = camera.get(cv2.CAP_PROP_FPS) height = camera.get(cv2.CAP_PROP_FRAME_HEIGHT) width = camera.get(cv2.CAP_PROP_FRAME_WIDTH) writer = cv2.VideoWriter(output_filepath, fourcc, fps, (int(width), int(height))) while(ret): frame_number += 1 writer.write(frame) if show_preview: cv2.imshow("preview", frame) if cv2.waitKey(int(1 / fps * 1000)) == 27: # ESC Key break if limit is not None and frame_number >= limit: break ret, frame = camera.read() if __name__ == '__main__': args = set_arguments() output_path = get_outputpath(args.output_directory) camera = set_camera(args.device_id, args.fps) capture(camera, output_path, args.preview, args.limit)
カメラ映像を動画に保存する
プログラムの使い方について説明します。
引数の渡し方は-hオプションで見ることができます。
python camera_recorder.py -h
こんな感じです。
usage: camera_recorder.py [-h] [-f FPS] [-o OUTPUT_DIRECTORY] [-p PREVIEW] device_id limit positional arguments: device_id 通常は、 0: 内蔵カメラ, 1: USBカメラ limit 撮影フレーム数 optional arguments: -h, --help show this help message and exit -f FPS, --fps FPS 撮影FPS -o OUTPUT_DIRECTORY, --output_directory OUTPUT_DIRECTORY 出力先ディレクトリ -p PREVIEW, --preview PREVIEW プレビュー画面の表示有無 0:表示しない 1:表示する(デフォルト)
それでは実際に動かしてみましょう。これは内臓カメラ(device_id = 0)で300フレーム分撮影してくれる実行スクリプトです。デフォルトではFPSは30なので、10秒間撮影をしてくれます。
python camera_recorder.py 0 300
カメラのプレビュー映像もちゃんと見えていますね。撮影時にはプレビュー表示があるとかなり便利です。
外部接続のUSBカメラで撮影をしたい場合は、次のようにするとできるはずです。
python camera_recorder.py 1 300
Pythonを使ってDynamoDBにJSONデータをインポート(アップロード)する
Pythonを使って、AWSを代表するNoSQLデータベースであるDynamoDBにJSONデータをインポート(アップロード)する手順およびスクリプトについてご紹介します。
目次
DynamoDBのテーブルを作成する
DynamoDBはNoSQLであるため、テーブル設計時にスキーマを定義する必要がありません。この辺りについては、NoSQLに関する記事でまとめましたので確認してみてください。
hktech.hatenablog.com
AWSにサインインし、DynamoDBマネジメントコンソールにアクセスします。
https://ap-northeast-1.console.aws.amazon.com/dynamodb/home
テーブルを作成します。
今回はセンサーの出力値を格納するテーブルを作成していきます。
プライマリキーは必須項目ですが、RDBMSとは異なり残りのカラム名を指定する必要がありません。テーブル名、プライマリキーはそれぞれ次のように設定します。
- テーブル名: sensor_value
- プライマリキー: id
Python を使ってDynamoDBにJSONデータをインポートする
PythonによるDynamoDBへのデータインポート方法について説明します。
まず、JSON形式のリストデータsensor_data.jsonを用意します。
[{"id": "data_000001", "value": "20"}, {"id": "data_000002", "value": "30"}]
aws_access_key_id, aws_secret_access_keyの中身をそれぞれの環境に合わせて変更し、次のプログラムを実行します。
from boto3.session import Session from decimal import Decimal import json def get_dynamo_table(key_id, access_key, table_name): session = Session( aws_access_key_id=key_id, aws_secret_access_key=access_key, region_name='ap-northeast-1' ) dynamodb = session.resource('dynamodb') dynamo_table = dynamodb.Table(table_name) return dynamo_table def insert_data_from_json(table, input_file_name): with open(input_file_name, "r") as f: json_data = json.load(f) with table.batch_writer() as batch: for record in json_data: record["value"] = Decimal("{}".format(record["value"])) # テーブル側で数値型を指定している場合はこのような処理が必要 batch.put_item(Item=record) print('Successfully inserted data.') if __name__ == '__main__': aws_access_key_id='idhogehoge' aws_secret_access_key='keyhogehoge' input_file_name = './sensor_data.json' dynamo_table = get_dynamo_table(aws_access_key_id, aws_secret_access_key, 'sensor_value') insert_data_from_json(dynamo_table, input_file_name)
DynamoDBマネジメントコンソールにアクセスし、データがちゃんと入っていることを確認できました。
GitLabでブランチを切ってMerge Requestを出す時のコマンドまとめ
GitLabを使ってチームで開発を進めている会社も多いかと思います。開発の仕事に関わっている方だと、「gitでdevelopからブランチ切ってMerge Request出しといて」なんて言われことがあるのではないでしょうか。慣れないうちは意味不明ですよね。本記事では、コマンドラインからMerge Requestを出す方法について初心者でもわかるようにご説明します。
目次
「ブランチを切る」とは
ブランチを切るとは、ある特定のブランチから枝分かれさせる形でそのブランチのコピーを作ることを指します。重要なブランチで直接ファイルを書き換えられてしまったりすると、コードにミスがあったり、他の人がそのブランチでなにかしらの作業をしているときに支障が出てしまうためこのような操作を行います。そのため、1回メインブランチのコピーを作ってそこでファイルなどの変更を行い、それをもとのブランチに戻すという操作をします。この元のブランチに戻す操作をmergeと呼び、そのmergeを依頼することをmerge request といいます。
ブランチを切ってmerge request を出す
開発チームの管理者でない限りはgitに直接mergeする権限がないことも多いため、merge requestを出し、レビューなどをしてもらってからmergeをしてもらうというのがよくある運用方法だと思います。Merge RequestはGitLab特有の表現です。普段GitHubなどを使う方はPull Requestまたはプルリクと同じものであると考えればよいでしょう。今回は次の図のようなツリー構造を持つgitを考えます。
実際の開発環境では、メインのブランチであるmasterブランチと、masterに反映する手前に用意されているdevelopブランチがあることがほとんどです。これを前提に、ブランチを切ってからmerge requestを出すまでの流れについて説明します。
ローカルにレポジトリを初めて落としてくる場合
コマンドラインで、自分がリポジトリを配置したいディレクトリに移動して下記のコマンドを打っていきます。
まずgitレポジトリをローカルにクローンし、developブランチに移動します。
git clone http://repository_name.git git checkout develop
ここで、developからブランチを切って新たな作業ブランチfix-api-bugを作成し、そのブランチに移動します。
git checkout -b fix_api_bug
または
git branch fix_api_bug git checkout fix_api_bug
作業ブランチでファイルを色々書き換えたあとは、git add でインデックスに追加し、git commit でコミット、最後にリモートレポジトリ origin にpushします。
git add <各ファイル名> OR git add . (編集した全てのファイル) git commit -m “fix api bug” git push origin fix_api_bug
すでにローカルにレポジトリをクローン済の場合
まずはdevelopブランチに移動し、git pull することでリモートレポジトリとローカルレポジトリを同じ状態にします。
git checkout develop git pull
残りの流れは、先ほどと同様にgit add, git commit を行い、最後にリモートレポジトリにpushします。
git checkout -b add_knn_model git add <各ファイル名> OR git add . (編集した全てのファイル) git commit -m “add knn model function” git push origin add_knn_model
gitの操作ミスを取り消したいときのコマンド
マージリクエストを出すまでにいくつかコマンドを打つ必要があるため、ミスをしてしまうことがあります。git add . をしたら関係ない周辺ファイルまでインデックスされてしまった、コミット文でスペルミスを見つけてしまったなど、直前の操作を取り消したい場合があると思います。よくあるgitコマンドの取り消し方法について簡単にまとめました。
git add を取り消したい場合
git reset
ちなみにこれは git reset –mixed HEAD を実行していることと同じです。
その辺りは、次の記事を参考にしてみてください。
https://qiita.com/shuntaro_tamura/items/db1aef9cf9d78db50ffe
git commitを取り消したい場合
git reset --soft HEAD^
まとめ
GitLabでブランチを切ってからMerge Requestまでの流れについておさらいしました。実際の業務ではさらに色々なコマンドを覚える必要がありますが、よほどのことがない限りはこれらのステップが理解できていれば問題ないと思います。エンジニアに限らず、ビジネスサイドの方でもgitを使う場面はあると思うので、この機会に確認しておいてはいかがでしょうか。
NoSQLのメリットと選ぶべき理由
NoSQLデータベースに触れる機会があったので、RDBMSとの違いを明らかにしながら利点や特徴についてまとめました。
目次
NoSQL とは
NoSQLとは Not only SQL の略であり、名前の通りデータの処理にSQLを必要としないデータベース (DB) です。1つのキーに対して1つの値を持つキー値型と、1つのキーに対して値だけでなく、配列など自由なデータ形式を持つことができるドキュメント型などが主に使われています。特徴としては、一意のキーに対して1つのデータが決まっているということがあります。これだけ書くと、MySQLやPostgresなどのRDBMSでもキーで検索して値を取り出せるし、そちらの方が便利なのでは?と感じる方も多いと思います。RDBMSとの違いを踏まえて、NoSQLのメリットを紹介していきます。
NoSQL のメリット
NoSQLが生まれた背景には、システムのあり方が多様化する中で、より柔軟なデータベースが求められていったということがあります。RDBMSは、データの複雑な関係構造を保持できる一方で、データの肥大化によって性能が劣化してしまうというデメリットがありました。ゲーム系のサービスや、アドテクノロジー、IoTシステムなど、大量のデータを扱うようなシステムのデータベースでは、データ量やユーザー数の増加・減少に応じて容量を柔軟に拡張したり、縮退させるような弾力的な拡張性が必要とされています。RDBMSは原則として1台のサーバーで動作するように設計されているのに対して、NoSQLはサーバを増加させると性能が改善するため、スケールアウトさせやすいというメリットがあります。そのため、データベース設計時にデータ量についてあまり考慮する必要がなく、ビッグデータと呼ばれるような大規模なデータ量を扱うときには強い味方となってくれるでしょう。
NoSQLを選ぶべき場面
ここまでで、NoSQLのメリットについて理解することができました。それでは実際にシステム要件を考えるときに、どのような場合にNoSQLデータベースを選択するべきかについて考えていきたいと思います。特にNoSQLが適しているのは、追加処理やデータ参照が多いシステムです。NoSQLはこれらの処理に最適化されているため、高速でデータの追加や参照が可能です。先ほど例を挙げたようなゲームのウェブサービスや、アドテク、センサデータなどを1秒間に何百件も扱うようなIoTシステムなどがこれに当てはまります。他にも画像データや音声データなどの非構造データを大量に保持したいときにはRDBMSよりもNoSQLのほうが有効でしょう。また、将来的にカラムを増やす可能性があるような時には、より柔軟なNoSQLデータベースのほうが扱いやすいと思います。
一方で、予約情報などのような厳密な処理が求められるシステムには適していないと言えます。しかし、最近ではこのような課題にも対応したNoSQLもあるようなので、開発するシステムに合ったデータベースを選んでいくことが大事だと思います。
まとめ
NoSQLのメリットと使用すべき場面についてご紹介しました。強みとしては、より柔軟に拡張でき、スケールアウトさせやすいということがありました。大量のデータを追加させていくようなシステムを開発する際は、ぜひ検討してみてはいかがでしょうか。
* 参考記事
https://academy.gmocloud.com/qa/20160509/2284
https://morizyun.github.io/blog/rdbms-nosql-merit-demerit/index.html
https://boxil.jp/mag/a3032/
https://employment.en-japan.com/engineerhub/entry/2017/11/22/110000
リスト内包表記を使ってPythonのリストの要素を条件付きで検索・抽出
コードを書く仕事をしていると、Pythonでリストの中から条件に一致する要素だけを抽出するようなスクリプトを書く機会も少なくないと思います。このようなとき、普通はforループを使って書いてしまいたくなるのですが、Pythonのリスト内包表記を使えば、「これ1行で書けるんじゃね?」となります。for文との違いを示しながら実装例をご紹介します。
目次
Pythonのリストから要素を条件付きで抽出
次のような文字列のリストがあるとします。このリストから文字列などの条件をもとに特定の要素のみを抜き出したいとします。
txt_files = ['データサイエンス集中講義_20180912.txt', ¥ 'データサイエンス集中講義_20180914.txt', ¥ 'Pythonプログラミング_20180912.txt', ¥ 'Pythonプログラミング_20180915.txt', ¥ 'Pythonプログラミング_20180916.txt']
文字列が完全一致する要素を抽出
まずは 'データサイエンス集中講義_20180912.txt' と完全一致する要素を抽出し、新たなリスト output_list に入れたいとします。
forループで記述
普通に書くと次のようになります。
output_list = [] for file in txt_files: if file == 'データサイエンス集中講義_20180912.txt': output_list.append(file)
リスト内包表記で記述
リスト内包表記を使えば1行で次のように書くことができます。簡単ですよね?
[file for file in txt_files if 'データサイエンス集中講義_20180912.txt' == file]
コードの比較
各変数がどのように対応しているかを色別でみていくと理解が深まると思います。
forループ
> for file in txt_files:
> ____if file == 'データサイエンス集中講義_20180912.txt':
>________output_list.append(file)
リスト内包表記
> [file for file in txt_files if 'データサイエンス集中講義_20180912.txt' == file]
ある文字列を含む要素を抽出
今度は、'データサイエンス' という文字列を含む要素を抽出したいとします。in 句を使うと、特定の文字列が含まれているかどうかを分岐させることができます。
forループで記述
さきほどと同じように、forループを使って書くと次のようになります。
output_list = [] for file in txt_files: if 'データサイエンス' in file: output_list.append(file)
リスト内包表記で記述
リスト内包表記を使えば、こちらも1行ですっきりと書くことができます。
[file for file in txt_files if 'データサイエンス' in file]
文字列の一部を抽出
最後に、実務で発生しそうな例を紹介します。こちらではテキストファイルのファイル名のうち、日付部分のみを取り出して条件付きで抽出したい場合を考えます。
forループで記述
pythonは文字列変換に強いので、1つの文だけで多くのことができます。
output_list = [] for file in txt_files: if file.replace('.txt', '').split('_')[1] == '20180912': output_list.append(file)
リスト内包表記で記述
難しい処理も次のように1行で書くことができます。
[file for file in txt_files if '20180912' == file.replace('.txt', '').split('_')[1]]
PythonのおすすめWebフレームワーク flask 入門
初心者向けに、おすすめのPythonのWeb開発フレームワークのひとつである flask の概要と簡単な実装例についてご紹介します。
目次
PythonのWebフレームワーク
Web開発というと Ruby の Ruby on Rails や PHP の Cake PHPや Laravel などが有名で、Pythonはあまり聞く機会が多くないかもしれません。しかし、PythonにもInstagramなどで使用されているDjangoという強力なWeb開発フレームワークや、Djangoと比べて軽量なフレームワークであるflaskなどがあります。今回の記事ではよりシンプルに実装ができるflaskについてご紹介します。
Django と flask の違い
簡単に Django と flask の用途の違いについて説明します。
Django
- 多くの機能を備えており、大規模なウェブ開発が可能。
- 実際のサービスでも多く使われている。
flask で Webサイトを作る
flaskを使うと、実にさまざまなウェブアプリケーションを開発することができます。今回はそのようなウェブアプリケーションを開発するための型を用意し、今後の開発の土台を作成していきたいと思います。
準備
まずは flask をインストールします。
pip install flask
ファイル構成
適当なディレクトリを作成し、ウェブサイトに必要なファイルを用意していきます。
flaskでは、基本的に以下のようなファイルの置き方に従います。
- ルートディレクトリ(メインのフォルダ)下にメインのpythonファイルを直接配置します。
- cssやjsファイルは、staticフォルダの下に配置します。
- htmlファイルは、templates フォルダの下に配置します。
├ app.py ├ static │ ├ css │ │ ∟style.css │ ├ js ├ templates ∟index.html
Python ファイル
pythonプログラムをwebアプリケーションのサーバーとして動かします。
@app.route('/') 以下の index 関数では、ウェブサイトのルートディレクトリにアクセスしたときの挙動について記述します。今回は、index.htmlを表示するように記述しています。
from flask import Flask, request, render_template app = Flask(__name__) @app.route('/') def index(): return render_template('index.html') if __name__ == "__main__": app.debug = True app.run(host='0.0.0.0', port=8080)
html ファイル
htmlファイルを用意します。cssやjsファイルも、通常と同じように読み込みます。
<html> <head> <link rel='stylesheet' type='text/css' href='static/css/style.css'> <script src="http://code.jquery.com/jquery-latest.js"></script> </head> <body> <div id='header'> <h1>デモサイト</h1> </div> <div id='contents'> ここにコンテンツを作成していきましょう! </div> </body> </html>
Webサイトを立ち上げる
pythonプログラムを実行する
app.pyが置かれたルートディレクトリへ移動し、pythonスクリプトを実行します。
python app.py
下記のような表示が出ていれば成功です。
* Running on http://0.0.0.0:8080/ (Press CTRL+C to quit) * Restarting with stat
サーバーにアクセスしてwebサイトを確認する
ブラウザを開き、下記のURLにアクセスしましょう。
http://0.0.0.0:8080/
このようにサイトが表示されていれば、flaskでWeb開発を行う土台は完成です。
まとめ
PythonのWeb開発フレームワークであるflaskについて紹介しました。今回はウェブサイトの型だけを作成したので、次回はこの土台を使ってflaskおよびpythonの強みを活かしたウェブアプリケーションを開発していきます。