ウシジの深層・強化学習の学習

深層学習、強化学習に関して、学んだことや論文、ニュース等を紹介していきます。よろしくお願いします。

Coursera Machine Learning: Week10 Large Scale Machine Learning

移転しました。

f:id:ushiji:20191203125437p:plain

CourseraのMachine Learningについてまとめています。 前回はWeek9の後半、Recommender Systemsについてまとめました。

今回は、Week10 Large Scale Machine Learningについて学びます。

  

 

Week10

Large Scale Machine Learning

Week6の後半で、大量のデータを学習に使うことにより、そこまで優れたアルゴリズムでなくとも、優れたアルゴリズムと同等かそれ以上の結果を出すことができるといったことを学びました。Week10では、そのような大きなデータをどのように扱うかについて学びます。

 

Stochastic gradient descent

これまで、様々な場面でGradient descentを用いてきましたが、ここではそれを改変した、Stochastic gradient descent(確率的最急降下法)について学びます。

まず、線形回帰モデルにGradient descentを適用した場合は、Hypothesisやコスト関数、Gradient descentは下記の数式で表せます。

f:id:ushiji:20191225122756p:plain

Linear regression with gradient descent

このとき、トレーニング用のデータ量mがかなり大きい場合、上記の数式で青の破線で囲んだコスト関数の微分項の計算コストがかなり大きくなってしまいます。ちなみに、これまで使ってきたこのGradient descentのことを、"Batch gradient descent"と言います。"Batch"とは、毎回の計算で全てのトレーニングデータを用いることを表しています。

このBatch gradient descentでは、全てのトレーニングデータを用いているのですが、計算量を減らすために、各イテレーションで1つのトレーニングデータだけを用いるのがStochastic gradient descentです。Stochastic gradient descentでは、コスト関数をトレーニングデータ全体に対するコスト関数ではなく、下記のようにトレーニングデータx(i), y(i)に対するコスト関数として定義します。ちなみに、このコスト関数をトレーニングデータ1~mまで計算し、平均をとれば、Batch gradient descentのコスト関数と一致します。

f:id:ushiji:20191225125239p:plain

Stochastic gradient descent

Stochastic gradient descentでは、まずは、トレーニングデータをシャッフルし、各トレーニングデータに関して、上記で記載した数式(定義したコスト関数の微分 × Learning rate)を用いてパラメータθを更新します。1~mまで全てのトレーニングデータを用いてパラメータをアップデートし、それを複数回繰り返します。これで、パラメータをフィットさせていきます。

Batch gradient descentでは、全てのトレーニングデータを用いるため、最適解(コストのGlobal minimum)に向けてパラメータが更新されていきますが、Stochastic gradient descentでは、ランダムにシャッフルされたトレーニングデータを1つずつ用いてパラメータを更新していくため、まっすぐにGlobal minimumに向かうわけではなく、蛇行しながら最適化されていくような形になります。また、Stochastic gradient descentでは、最終的には収束せず、Global minimumの周囲をうろちょろするような形になります。ただし、Global minimumのすぐそばの領域であれば、Global minimumに極めて近いため、仮説としては十分なようです。

また、パラメータの更新を繰り返す回数ですが、1~10回程度が一般的なようです。

 

Mini-Batch Gradient Descent

Batch gradient descentでは、全てのトレーニングデータを用いてパラメータを更新し、Stochastic gradient descentでは、1つのトレーニングデータを用いてパラメータを更新しました。その間の手法として、いくつかのトレーニングデータを用いてパラメータを更新していく手法のことを"Mini-batch gradient descent"と言います。また、その際に用いるデータの数をMini-batch sizeと言います。トレーニングデータのサイズmが1,000、ミニバッチサイズbが10の場合、Mini-batch gradient descentのアルゴリズムは、下記のように記載できます。

f:id:ushiji:20191225131125p:plain

Mini-batch gradient descent

ミニバッチサイズは、2~100の間で、10程度が一般的なようです。

 

Stochastic gradient descent convergence 

ここまで、Stochastic gradient descentとMini-batch gradient descentのアルゴリズムについて学びました。次に、Stochastic gradient descentにおいて、どのように収束しているかを確認するのかについて学びます。また、Learning rate αの小技も紹介されます。

収束の確認をする際、Batch gradient descentでは、イテレーション数に対してコスト値をプロットし、収束を確認していました(Week6前半のLearning Curves参照)。Stochastic gradient descentでは、その確認のために、全てのトレーニングデータを加味したコストをそのために計算するのであれば、せっかく1つだけのデータを用いて計算量を減らしたのに、計算量が増えてしまいます。そのため、本記事の上の方で定義したトレーニングデータx(i), y(i)に対するコスト関数を用いて、ある程度の計算ごとに、その平均のコストを計算、プロットし、収束しているかを確認します。例えば、1,000個分のデータに対して計算するごとに、その1,000個分のデータのコストの平均を計算し、プロットします。

次に、Learning rate αについて。Stochastic gradient descentでは、最終的には収束せず、Global minimumの周囲をうろちょろするような形になります。これをよりGlobal minimumに近づけるために、Learning rateを定数ではなく、下記のような数式に置き換え、イテレーションに合わせて減少させていくという手法があるそうです。

f:id:ushiji:20191225134301p:plain

Learning rate

ここで、Const1、Const2は定数です。ただ、そこまでGlobal minimumにこだわらなければ、αを定数としてしまっても十分なようです。

 

Online learning

これまでは、すでに準備されたトレーニングデータから学習するアルゴリズムについて学んできました。ここでは、次々と新しいデータが得られ、それを用いてモデルを更新していく場合に用いる、オンライン学習について学びます。

例えば、Web上で受け付ける配達サービスを運営しており、出荷元から出荷先までの地点をもとに、配達価格を提示していたとします。ユーザーがサービスを利用した場合をy=1、利用しなかった場合をy=0とし、Logistic regressionの問題として扱い、どのような場合にy=1となるのかを予測し、配達価格を最適化したいとします。この時に扱うFeatureは、ユーザーのプロパティや出荷元/出荷先、提示価格のデータで、新たにユーザーのデータが得られるたびに、下記のように計算し、予測のパラメータを更新していきます。

f:id:ushiji:20191225140512p:plain

Online learning

Map-reduce and data parallelism

これまでは、1台のPCで計算を行うことを前提としてきましたが、複数台のPCに作業を割り振り、計算時間を短縮することも可能です。

その1つの方法が、Map-reduceで、例えば400のトレーニングデータがあるとして、下記のように4つのPCで100ずつ計算し、その結果をマスターのサーバで集め、パラメータの更新を計算するといったことを行います。これにより、およそ4倍のスピードアップを行うことができます。

f:id:ushiji:20191225142243p:plain

Map-reduce

 

プログラミング演習

Week10では、プログラミング演習はありません。(Week11もなく、Week9でプログラミング演習は最後となります。)

 

 

 

次回は、Week11 Application Example: Photo OCRについてまとめます。長かった講義も次で最後です。

ushiji.hatenablog.com

 

 

コース全体の目次とそのまとめ記事へのリンクは、下記の記事にまとめていますので、参照ください。

ushiji.hatenablog.com