Cucco’s Compute Hack

コンピュータ関係の記事を書いていきます。

Neural Network Consoleによる学習済みニューラルネットワークの利用

環境構築

ベースはanaconda。nnablaがなく、importに失敗するのでパッケージ追加しました。以下のコマンドを管理者で開いたコンソールで実行。pip自体の更新が必要な場合もあり。

pip install ipykernel
pip install nnabla

コードの作成と推論の実行

以下を参考に、アヤメのデータで学習して、推論するところまで。
学習はGUI使いました。学習済みパラメータファイルが20180520_142204というフォルダにある想定です。

チュートリアル:Neural Network Consoleによる学習済みニューラルネットワークのNeural Network Librariesを用いた利用方法2種 – Docs - Neural Network Console

import nnabla as nn
import nnabla.functions as F
import nnabla.parametric_functions as PF

#loss関数を削るのでyが不要になる。
#def network(x, y,test=False):
def network(x, test=False):
    # Input -> 4
    # BatchNormalization
    with nn.parameter_scope('BatchNormalization'):
        h = PF.batch_normalization(x, (1,), 0.9, 0.0001, not test)
    # Affine -> 50
    with nn.parameter_scope('Affine'):
        h = PF.affine(h, (50,))
    # ReLU
    h = F.relu(h, True)
    # Affine_2 -> 20
    with nn.parameter_scope('Affine_2'):
        h = PF.affine(h, (20,))
    # ReLU_2
    h = F.relu(h, True)
    # Dropout
    if not test:
        h = F.dropout(h)
    # Affine_3 -> 3
    with nn.parameter_scope('Affine_3'):
        h = PF.affine(h, (3,))
    # Softmax
    h = F.softmax(h)
    # SquaredError
    # 不要なのでコメントアウト
    # h = F.squared_error(h, y)
    return h

以降が追加した推論用のコード。実際には上のコードと同じファイルに記載。

# load parameters
# \を/に書き換え必要あり。オリジナルのブログはシングルクオートが全角なので書き換えが必要。
nn.load_parameters('C:/iris_sample/iris_sample.files/20180520_142204/parameters.h5')

# Prepare input variable
# 3つ推論する場合
x=nn.Variable((3,4))
# 1つ推論する場合
x1=nn.Variable((1,4))
x2=nn.Variable((1,4))
x3=nn.Variable((1,4))

# Let input data to x.d
# x.d = ...
x.d=[[5.1, 3.5, 1.4, 0.2],[7,   3.2, 4.7, 1.4],[6.3, 3.3, 6,  2.5]]
#yの1番目が大きな値になる。listを作って渡せばいいのは楽。
x1.d=[5.1, 3.5, 1.4, 0.2]
#yの2番目が大きな値になる。
x2.d=[7,   3.2, 4.7, 1.4]
#yの3番目が大きな値になる。
x3.d=[6.3, 3.3, 6,  2.5]

# Build network for inference
# test=Trueで、ドロップアウトの機能を停止する。BatchNormalizationへの影響は不明。
y = network(x, test=True)

# Execute inference
y.forward()
print(y.d)

実行結果。指数表示なので大小関係わかりにくいがあってるはず。

#[[9.9891686e-01 9.9650992e-04 8.6744032e-05]
# [7.5185834e-03 7.6037079e-01 2.3211062e-01]
# [7.9514321e-06 3.5711315e-02 9.6428072e-01]]