Auto-identify statsmodels' ARIMA/SARIMA in python

Posted on January 8, 2017 by Ilya

The ARIMA/ARIMAX/SARIMAX implementation in Python’s statsmodels is great, but it lacks an automatic order-identification routine. What follows is a solution using grid search.

Let’s do some imports.


```python
import pandas as pd
import numpy as np

from statsmodels.tsa.arima_model import ARIMA
from statsmodels.tsa.statespace.sarimax import SARIMAX
from pprint import pprint
from statsmodels.tsa.stattools import adfuller
from scipy.optimize import brute

from scipy import stats

import matplotlib.pyplot as plt
import warnings
```

To identify the order of differencing one can:

  1. Run the ADF stationarity test on the original data; if the data is stationary, the order of differencing is \(d=0\). Stop here.
  2. If the data is non-stationary, difference it repeatedly until it becomes stationary; if that takes \(k\) rounds of differencing, then \(d=k\). Done.

```python
def findDorder(series, maxD=10):
    """Report the ADF test p-value for each differencing order below ``maxD``.

    For every ``d`` in ``range(maxD)`` the flattened series is differenced
    ``d`` times and handed to ``adfuller``; element ``[1]`` of its result is
    kept as the p-value.  The candidates are pretty-printed sorted by
    p-value, the pair with the smallest p-value is printed, and the unsorted
    list is returned.

    Args:
        series: object exposing ``.values`` (e.g. a pandas Series); must not
            contain NaN.
        maxD: exclusive upper bound on the differencing orders tried.

    Returns:
        List of ``(d, p_value)`` tuples in increasing order of ``d``.

    Raises:
        ValueError: if ``series`` contains NaN.
    """
    if np.any(pd.isnull(series)):
        raise ValueError("series contains NaN")

    flat = series.values.flatten()

    ret = []
    for order in range(maxD):
        p_value = adfuller(np.diff(flat, order))[1]
        ret.append((order, p_value))

    # Most stationary-looking candidates (smallest p-value) come first.
    by_p_value = sorted(ret, key=lambda pair: pair[1])
    pprint([{'d': order, 'p-value': p_value} for order, p_value in by_p_value])
    print(min(ret, key=lambda pair: pair[1]))
    return ret
```

For the other ARIMA parameters — \(p\), the order of the autoregressive part, and \(q\), the order of the moving-average part — one can use a grid search that minimizes AIC/BIC. The same search can also cover \(d\) (or you can fix \(d\) to the value found with the ADF test and differencing):


```python
def find_best_pdq_ARIMA(
        endog,
        grid,
        huge=float('inf'),
        measure=lambda x: x.aic,
        **kwa
        ):
    """Grid-search the ARIMA order ``(p, d, q)`` that minimizes ``measure``.

    Every point of ``grid`` is fitted with ``ARIMA(endog, order=..., **kwa)``
    and scored with ``measure``.  A candidate whose fit or scoring raises an
    ordinary exception, or whose score is not finite, is scored ``huge`` so
    the search carries on.  Each ``(order, score)`` pair is printed.

    Args:
        endog: the observed series, passed straight to ``ARIMA``.
        grid: tuple of three ranges (e.g. ``slice`` objects) for p, d and q,
            in the form accepted by ``scipy.optimize.brute``.
        huge: score assigned to failed or non-finite candidates.
        measure: callable mapping a fitted result to the score to minimize
            (AIC by default).
        **kwa: extra keyword arguments forwarded to ``ARIMA``.

    Returns:
        The best ``(p, d, q)`` as a tuple of ints.
    """

    def objfunc(order):
        # brute() hands over a numpy vector; the model wants plain ints.
        order = tuple(int(v) for v in order)
        # catch_warnings() keeps any warning-filter changes made during the
        # fit from leaking out of this evaluation.
        with warnings.catch_warnings():
            try:
                # Pass the order by keyword: the second positional parameter
                # is not `order` in every statsmodels ARIMA class.
                r = measure(ARIMA(endog, order=order, **kwa).fit())
            except Exception as e:
                # Deliberately not BaseException: Ctrl-C / SystemExit must
                # still be able to abort a long search.  Warnings promoted
                # to errors are Exception subclasses and are covered here.
                print(e)
                r = huge
        if not np.isfinite(r):
            r = huge
        print((order, r))
        return r

    return tuple(map(int, brute(objfunc, grid, finish=None)))


def reTuple(paramsFull):
    """Split a flat parameter sequence into ``(order, seasonal_order)``.

    Every element is converted with ``int``.  The first three values form
    the first tuple and whatever remains forms the second.
    """
    as_ints = tuple(int(value) for value in paramsFull)
    return as_ints[:3], as_ints[3:]


def find_best_pars_SARIMAX(
        endog,
        grid,
        seasonal_grid,
        huge=float('inf'),
        measure=lambda x: x.aic,
        **kwa
        ):
    """Grid-search the SARIMAX orders that minimize ``measure``.

    ``grid`` and ``seasonal_grid`` are concatenated into one search space.
    Each point is split by ``reTuple`` into ``order`` (first three values)
    and ``seasonal_order`` (the rest), fitted with ``SARIMAX`` and scored
    with ``measure``.  A candidate whose fit or scoring raises an ordinary
    exception, or whose score is not finite, is scored ``huge`` so the
    search carries on.  Each ``((order, seasonal_order), score)`` pair is
    printed.

    Args:
        endog: the observed series, passed straight to ``SARIMAX``.
        grid: ranges (e.g. ``slice`` objects) for the non-seasonal order,
            in the form accepted by ``scipy.optimize.brute``.
        seasonal_grid: ranges for the seasonal order, same form.
        huge: score assigned to failed or non-finite candidates.
        measure: callable mapping a fitted result to the score to minimize
            (AIC by default).
        **kwa: extra keyword arguments forwarded to ``SARIMAX``.

    Returns:
        ``(order, seasonal_order)`` of the best candidate, as tuples of ints.
    """

    def objfunc(paramsFull):
        parTup = reTuple(paramsFull)
        param, param_seasonal = parTup

        # catch_warnings() keeps any warning-filter changes made during the
        # fit from leaking out of this evaluation.
        with warnings.catch_warnings():
            try:
                res = SARIMAX(
                    endog,
                    order=param,
                    seasonal_order=param_seasonal,
                    **kwa
                ).fit()
                r = measure(res)
            except Exception as e:
                # Deliberately not BaseException: Ctrl-C / SystemExit must
                # still be able to abort a long search.  Warnings promoted
                # to errors are Exception subclasses and are covered here.
                print(e)
                r = huge
        if not np.isfinite(r):
            r = huge
        print((parTup, r))
        return r

    fullGrid = tuple(grid) + tuple(seasonal_grid)
    return reTuple(map(int, brute(objfunc, fullGrid, finish=None)))
```

Run some tests:


```python
def test1():
    """Demo: grid-search an ARIMA order on synthetic data and plot the fit.

    The series is a linear trend plus two sinusoids and unseeded Gaussian
    noise, so results differ from run to run.  The best order found by
    ``find_best_pdq_ARIMA`` is printed, refitted, and its in-sample
    prediction is drawn over the data in a new figure.
    """
    t = np.linspace(0, 20, 100)
    y = (
        0.1 * np.sin(2 * np.pi * t / 7.0)
        + np.random.randn(len(t))
        + 1.5 * np.sin(2 * np.pi * t / 1)
        + 0.2 * t
    )
    # Search p in 0..3, d in 0..2, q in 0..3 (slice ends are exclusive).
    grid = (slice(0, 3+1), slice(0, 2+1), slice(0, 3+1))
    best_order = find_best_pdq_ARIMA(y, grid)
    print(best_order)
    res = ARIMA(y, best_order).fit()
    # NOTE(review): presumably typ='levels' is only accepted when d > 0 in
    # the legacy ARIMA API -- confirm before relying on a d == 0 result.
    ypred = res.predict(typ='levels')

    plt.figure()
    pd.Series(y, index=t).plot()
    # The prediction may be shorter than the data (presumably differencing
    # drops the leading observations -- TODO confirm), so line it up with
    # the tail of t instead of the head.
    pd.Series(ypred, index=t[len(t) - len(ypred):]).plot()


def test2():
    """Demo: grid-search SARIMAX orders on synthetic data and inspect the fit.

    The series is a linear trend plus two sinusoids and unseeded Gaussian
    noise, so results differ from run to run.  The best orders found by
    ``find_best_pars_SARIMAX`` are printed and refitted; the prediction is
    plotted over the data, the diagnostics figure is drawn, and a normality
    test of the residuals is printed.
    """
    t = np.linspace(0, 20, 100)
    y = (
        0.1 * np.sin(2 * np.pi * t / 7.0)
        + np.random.randn(len(t))
        + 1.5 * np.sin(2 * np.pi * t / 1)
        + 0.2 * t
    )
    series = pd.Series(y, index=t)
    endog = series.dropna().values

    grid  = (slice(0, 3+1), slice(0, 1+1), slice(0, 3+1))
    # The last slice is the seasonal period: 0, 4 or 8.
    sgrid = (slice(0, 3+1), slice(0, 1+1), slice(0, 3+1), slice(0, 10, 4))
    # One set of options shared by the search and the final fit, so the
    # refitted model is the one that was actually scored.
    kwa = dict(enforce_stationarity=False,
               enforce_invertibility=False)

    best_order = find_best_pars_SARIMAX(endog, grid, sgrid, **kwa)
    print(best_order)
    param, param_seasonal = best_order
    res = SARIMAX(
        endog,
        order=param,
        seasonal_order=param_seasonal,
        **kwa
    ).fit()

    plt.figure()
    series.plot()
    ypred = res.predict(typ='levels')
    pd.Series(ypred, index=series.index[:len(ypred)]).plot()

    # No plt.figure() here: with no `fig` argument plot_diagnostics draws
    # into a figure of its own, so an extra one would just stay empty.
    res.plot_diagnostics(figsize=(15, 12))

    print(stats.normaltest(res.resid))


if __name__ == '__main__':
    test1()
    test2()
    # The demos only create figures; when run as a plain script nothing is
    # displayed until show() is called.
    plt.show()
```