Home
  • Home
  • Categories
  • Tags
  • Archives

Follow the Regularized Leader with Adaptive Decaying Proximal

Derivation of FTRL-ADP¶

Problem¶

\begin{equation} \label{eq:ftrl_dp} \begin{split} w_{t+1} = \underset{w}{\operatorname{argmin}} &\bigg\{g^\top _{1:t}w + \lambda_1\|w\|_1 + \frac{1}{2}\lambda_2\|w\|_2^2 + \frac{1}{2}\lambda_p\sum_{s=1}^{t}\sigma_{t,s} \|w-w_s\|^2_2\bigg\} \\ &\text{ in which } g^\top _{1:t} = \sum_{i=1}^{t}g^\top _i \text{ and } \sigma_{t,s} = \gamma^{t-s} \end{split} \end{equation}

This is a variant of the FTRL-proximal proposed by McMahan et al. in Ad click prediction: a view from the trenches.

Theorem¶

\begin{equation} \label{eq:update_func} w_{t+1,i} = \begin{cases} 0 \hskip 27mm \text{ if $\| z_{t,i}\|_1 \le \lambda_1$} \\[2ex] -\frac{z_{t,i}-\lambda_1 \mathrm{sign}(z_{t,i})}{\lambda_2+\lambda_p\frac{1-\gamma^{t}}{1-\gamma}} \hskip 4mm\text{ otherwise.} \end{cases} \end{equation}$$\text{in which } z_t=g_{1:t}-\lambda_p\sum_{s=1}^t \gamma^{t-s} w_s $$

This theorem leads to an efficient recursive algorithm to compute $w_t$ at each time step.

Theorem¶

Suppose that $\|w_t\|_2 \le R$ and $\|g_t\|_2 \le G$. With $\lambda_1=\lambda_2=0$ and $\lambda_p=1$, we have the following regret bound for FTRL-DP: \begin{equation}\label{eg:dpregret} Regret(w^*) \le 2R^2 \frac{1-\gamma^T}{1-\gamma} + \frac{G^2}{2} \frac{1-\gamma}{\gamma^T}\sum_{t=1}^{T}\frac{\gamma^t}{1-\gamma^t} \end{equation}

We prove that a proper choice of $\gamma$ can lead to a sublinear growth of the expression on the right-hand side. Specifically, if we choose $\gamma = 1 - \frac{\ln T}{2T}$, we have the following sublinear regret bound: $$Regret(w^*) \le 4R^2\frac{T}{\ln T} + \frac{G^2}{2}\frac{(1 + \ln T)}{(1 - \frac{\ln T}{2T})^T}$$

This theorem leads to a formula for updating the decay rate so as to ensure a sublinear regret bound.

FTRL-ADP and Concept Drift¶

import os
from sklearn.datasets import load_svmlight_file
import numpy as np
import matplotlib.pyplot as plt
from ftrl_adp import FTRL_ADP
from IPython.display import HTML
from utils import animate
import matplotlib
from scipy.sparse import coo_matrix, hstack, csr_matrix 
import matplotlib.animation as animation
%matplotlib notebook

font = {'size'   : 4}
matplotlib.rc('font', **font)
SIM_SPEED = 100

def animate(i):
    """Advance the concept-drift animation by one frame.

    Frame ``i`` covers the rows of ``X_input`` up to ``(i + 1) * SIM_SPEED``
    (clipped to the dataset size). For that frame the function:

    1. redraws the scatter plot with the most recent (at most 200) rows,
       fading older points through the alpha channel;
    2. feeds every row not yet consumed (``r_bound_last`` .. ``r_bound``) to
       each classifier in ``algs`` and records the per-sample error and the
       decay value returned by ``classifier.fit``;
    3. redraws the decision line, the running error-rate curve and, for
       algorithms whose key contains ``'adp'``, the decay curve.

    The function reads the module-level objects ``X_input``, ``Y_label``,
    ``SIM_SPEED``, ``plot_points``, ``algs``, ``all_errors``, ``all_decays``,
    ``d_lines``, ``e_lines``, ``c_lines`` and ``xlimit``, and updates the
    module-level counter ``r_bound_last``.

    Args:
        i: Zero-based frame index supplied by ``FuncAnimation``.

    Returns:
        The scatter artist ``plot_points``, or ``None`` when ``i`` lies past
        the end of the dataset.
    """
    global r_bound_last

    n_rows = X_input.shape[0]

    # Frames beyond the end of the data have nothing left to draw or learn.
    if float(n_rows) / SIM_SPEED < i:
        return

    # r_bound is clipped to n_rows, so the window below is always valid.
    r_bound = min((i + 1) * SIM_SPEED, n_rows)
    l_bound = max(r_bound - 200, 0)

    X_batch = X_input[l_bound:r_bound, 1:3]
    Y_batch = Y_label[l_bound:r_bound]
    # Builtin ``int`` replaces the removed ``np.int`` alias (same dtype).
    labels = Y_batch.astype(int)

    # Draw data points
    plot_points.set_offsets(X_batch.toarray())
    plot_points.set_array(labels)
    plot_points.set_cmap('cool')

    # Older points in the window are more transparent; the label drives the
    # blue channel.
    n_points = X_batch.shape[0]
    rgba_colors = np.zeros((n_points, 4))
    rgba_colors[:, 3] = np.linspace(0.0, 1, n_points)
    rgba_colors[:, 2] = labels / 4.
    plot_points.set_color(rgba_colors)

    # Update the model
    for alg in algs:
        classifier = algs[alg]

        # Only rows that no earlier frame has consumed are used for training,
        # so a frame that is drawn twice does not train twice.
        for row in range(r_bound_last, r_bound):
            sample = X_input[row]
            y = Y_label[row]
            p, decay = classifier.fit(sample.indices, sample.data, y)
            # Stored as a one-element list; np.cumsum below flattens it.
            error = [int(np.abs(y - p) > 0.5)]

            all_errors[alg].append(error)
            all_decays[alg].append(decay)

        # Draw the decision line
        indices = np.arange(X_input.shape[1])
        x = X_input[r_bound - 1].toarray().ravel()
        weight = classifier.weight_update(indices)
        # Contribution of every feature other than the two plotted ones
        # (columns 1 and 2), evaluated at the latest sample.
        offset = np.sum(weight[3:] * x[3:]) + weight[0] * x[0]

        # A zero weight on the y-axis feature would make the line vertical
        # (division by zero), so the previous line is left in place.
        if weight[2] != 0:
            line_points = np.array(
                [[xlimit[0], (-offset - weight[1] * xlimit[0]) / weight[2]],
                 [xlimit[1], (-offset - weight[1] * xlimit[1]) / weight[2]]])
            d_lines[alg].set_data(line_points[:, 0], line_points[:, 1])

        # Draw the error lines (cumulative mean of the 0/1 errors)
        errors = all_errors[alg]
        steps = np.arange(len(errors))
        e_lines[alg].set_data(steps, np.cumsum(errors) / (steps + 1.0))

        # Draw the decay line
        if 'adp' in alg:
            decays = all_decays[alg]
            c_lines[alg].set_data(np.arange(len(decays)), decays)

    r_bound_last = r_bound
    return plot_points
# Load data
e_range = (0.2, .4)   # y-axis limits of the error-rate plot
d_range = (0.9, 1.)   # y-axis limits of the decay plot

X_input, Y_label = load_svmlight_file('dataset.txt')
# Prepend a constant column of ones, then keep only the first three columns
# (the ones column plus the first two features of the file).
temp = csr_matrix(np.ones((X_input.shape[0], 1)))
X_input = hstack([temp, X_input]).tocsr()[:,:3]

ftrl_adp = FTRL_ADP(decay = 1.0, L1=0., L2=0., LP = 1., adaptive=True, n_inputs=X_input.shape[1])

# Classifiers to animate, keyed by name, and the line colour used for each.
algs = {'ftrl_adp': ftrl_adp}
clrs = {'ftrl_adp': 'blue'}

fig = plt.figure(figsize=(6,2))

# Make the scatter plot on the left
# Column-wise extrema are computed once and reused for both axes.
col_min = np.min(X_input, axis=0).toarray()[0]
col_max = np.max(X_input, axis=0).toarray()[0]
xlimit = (col_min[1]-0.1, col_max[1]+0.1)
ylimit = (col_min[2]-0.1, col_max[2]+0.1)
scatterplot = fig.add_subplot(131, autoscale_on=False, xlim=xlimit, ylim=ylimit)
plot_points = scatterplot.scatter(np.array([0]), np.array([0]))

# Make the error plot at the middle
errorplot = fig.add_subplot(132, autoscale_on=False, xlim=(-1, X_input.shape[0]), ylim=e_range)

# Make the decay plot on the right
decayplot = fig.add_subplot(133, autoscale_on=False, xlim=(-1, X_input.shape[0]), ylim=d_range)

# Not read by animate(), which uses its own local ``errors``; kept so the
# module-level name still exists.
errors = []

# Per-algorithm histories that animate() appends to.
all_errors = {alg: [] for alg in algs}
all_decays = {alg: [] for alg in algs}

# One empty line artist per algorithm on each subplot; animate() fills them.
# The colour is given only through ``color=``: the original also passed the
# format string 'r-', whose red was overridden by the keyword anyway.
d_lines = {}
e_lines = {}
c_lines = {}
for alg in algs:
    decision_line, = scatterplot.plot([], [], '-', linewidth=2, color=clrs[alg])
    d_lines[alg] = decision_line

    error_line, = errorplot.plot([], [], '-', linewidth=1, color=clrs[alg])
    e_lines[alg] = error_line

    decay_line, = decayplot.plot([], [], '-', linewidth=1, color=clrs[alg])
    c_lines[alg] = decay_line

# First row index that animate() has not yet fed to the classifiers.
r_bound_last = 0
anim = animation.FuncAnimation(fig, animate, frames=350, interval=200, repeat=False)
HTML(anim.to_html5_video())
#plt.show()
# NOTE(review): the call is repeated in the original; it looks like the last
# line of a separate notebook cell — confirm before removing either one.
HTML(anim.to_html5_video())
Your browser does not support the video tag.
Comments
comments powered by Disqus

  • « 3D Vehicle Detection
  • Self Driving Toy Car »

Published

Mei 15, 2017

Category

Theoretical ML

Tags

  • Analysis 2
  • Linear Algebra 6
  • Powered by Pelican. Theme: Elegant by Talha Mansoor