
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/model_selection/plot_grid_search_digits.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_auto_examples_model_selection_plot_grid_search_digits.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_model_selection_plot_grid_search_digits.py:


============================================================
Custom refit strategy of a grid search with cross-validation
============================================================

This example shows how a classifier is optimized by cross-validation,
which is done using the :class:`~sklearn.model_selection.GridSearchCV` object
on a development set that comprises only half of the available labeled data.

The performance of the selected hyper-parameters and trained model is
then measured on a dedicated evaluation set that was not used during
the model selection step.

More details on tools available for model selection can be found in the
sections on :ref:`cross_validation` and :ref:`grid_search`.

.. GENERATED FROM PYTHON SOURCE LINES 19-26

The dataset
-----------

We will work with the `digits` dataset. The goal is to classify images of
handwritten digits.
We transform the problem into a binary classification for easier
understanding: the goal is to identify whether a digit is `8` or not.

.. GENERATED FROM PYTHON SOURCE LINES 26-30

.. code-block:: default

    from sklearn import datasets

    digits = datasets.load_digits()








.. GENERATED FROM PYTHON SOURCE LINES 31-34

In order to train a classifier on images, we need to flatten them into vectors.
Each image of 8 by 8 pixels needs to be transformed to a vector of 64 pixels.
Thus, we will get a final data array of shape `(n_images, n_pixels)`.

.. GENERATED FROM PYTHON SOURCE LINES 34-41

.. code-block:: default

    n_samples = len(digits.images)
    X = digits.images.reshape((n_samples, -1))
    y = digits.target == 8
    print(
        f"The number of images is {X.shape[0]} and each image contains {X.shape[1]} pixels"
    )





.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    The number of images is 1797 and each image contains 64 pixels
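
Because only one of the ten digits is treated as the positive class, the
binary target is imbalanced. The following minimal sketch counts the positive
samples; the names `is_eight` and `n_positive` are introduced here for
illustration only and are not part of the example.

```python
import numpy as np
from sklearn import datasets

digits = datasets.load_digits()

# Binary target: True where the digit is an 8, False otherwise.
is_eight = digits.target == 8

n_positive = int(np.sum(is_eight))
print(
    f"{n_positive} of {is_eight.size} images are an 8 "
    f"({n_positive / is_eight.size:.1%})"
)
```

With such an imbalance, a classifier that always answers "not 8" would still
reach a high accuracy, which is one reason to look at precision and recall
instead.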




.. GENERATED FROM PYTHON SOURCE LINES 42-44

As presented in the introduction, the data will be split into a training
and a testing set of equal size.

.. GENERATED FROM PYTHON SOURCE LINES 44-48

.. code-block:: default

    from sklearn.model_selection import train_test_split

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=0)
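
To check what `test_size=0.5` does, the sketch below applies the same call to
a small synthetic dataset and prints the resulting shapes and positive rates.
The data and the `*_demo` names are assumptions made for illustration.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Synthetic stand-in data: 200 samples, 4 features, roughly 10% positives.
rng = np.random.default_rng(0)
X_demo = rng.normal(size=(200, 4))
y_demo = rng.random(200) < 0.1

X_train_demo, X_test_demo, y_train_demo, y_test_demo = train_test_split(
    X_demo, y_demo, test_size=0.5, random_state=0
)

print(X_train_demo.shape, X_test_demo.shape)
# The positive rate can differ between the two halves because the split
# is random and not stratified.
print(y_train_demo.mean(), y_test_demo.mean())
```

If equal class proportions in both halves matter, `train_test_split` accepts a
`stratify` argument; the example on this page does not use it.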








.. GENERATED FROM PYTHON SOURCE LINES 49-55

Define our grid-search strategy
-------------------------------

We will select a classifier by searching for the best hyper-parameters on
folds of the training set. To do this, we need to define the scores used to
select the best candidate.

.. GENERATED FROM PYTHON SOURCE LINES 55-58

.. code-block:: default


    scores = ["precision", "recall"]
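
As a reminder, precision is the fraction of predicted positives that are truly
positive, and recall is the fraction of true positives that are detected. The
sketch below computes both on a handful of hand-written labels; the label
values are invented for illustration.

```python
from sklearn.metrics import precision_score, recall_score

# Hypothetical labels: 4 positives and 6 negatives.
y_true = [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]
# The predictions find 2 of the 4 positives and raise 1 false alarm.
y_pred = [1, 1, 0, 0, 0, 0, 0, 0, 0, 1]

precision = precision_score(y_true, y_pred)  # 2 / (2 + 1)
recall = recall_score(y_true, y_pred)  # 2 / (2 + 2)
print(f"precision={precision:.3f}, recall={recall:.3f}")
# prints "precision=0.667, recall=0.500"
```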








.. GENERATED FROM PYTHON SOURCE LINES 59-70

We can also define a function to be passed to the `refit` parameter of the
:class:`~sklearn.model_selection.GridSearchCV` instance. It will implement the
custom strategy to select the best candidate from the `cv_results_` attribute
of the :class:`~sklearn.model_selection.GridSearchCV`. Once the candidate is
selected, it is automatically refitted by the
:class:`~sklearn.model_selection.GridSearchCV` instance.

Here, the strategy is to shortlist the models that are the best in terms of
precision and recall. From these, we select the model that is the fastest at
predicting. Note that these custom choices are completely arbitrary.

.. GENERATED FROM PYTHON SOURCE LINES 70-167

.. code-block:: default


    import pandas as pd


    def print_dataframe(filtered_cv_results):
        """Pretty print for filtered dataframe"""
        for mean_precision, std_precision, mean_recall, std_recall, params in zip(
            filtered_cv_results["mean_test_precision"],
            filtered_cv_results["std_test_precision"],
            filtered_cv_results["mean_test_recall"],
            filtered_cv_results["std_test_recall"],
            filtered_cv_results["params"],
        ):
            print(
                f"precision: {mean_precision:0.3f} (±{std_precision:0.03f}),"
                f" recall: {mean_recall:0.3f} (±{std_recall:0.03f}),"
                f" for {params}"
            )
        print()


    def refit_strategy(cv_results):
        """Define the strategy to select the best estimator.

        The strategy defined here is to filter out all results below a precision threshold
        of 0.98, rank the remaining by recall and keep all models within one standard
        deviation of the best by recall. Once these models are selected, we can select the
        fastest model to predict.

        Parameters
        ----------
        cv_results : dict of numpy (masked) ndarrays
            CV results as returned by the `GridSearchCV`.

        Returns
        -------
        best_index : int
            The index of the best estimator as it appears in `cv_results`.
        """
        precision_threshold = 0.98

        # Print the info about the grid-search for the different scores
        cv_results_ = pd.DataFrame(cv_results)
        print("All grid-search results:")
        print_dataframe(cv_results_)

        # Filter out all results below the threshold
        high_precision_cv_results = cv_results_[
            cv_results_["mean_test_precision"] > precision_threshold
        ]

        print(f"Models with a precision higher than {precision_threshold}:")
        print_dataframe(high_precision_cv_results)

        high_precision_cv_results = high_precision_cv_results[
            [
                "mean_score_time",
                "mean_test_recall",
                "std_test_recall",
                "mean_test_precision",
                "std_test_precision",
                "rank_test_recall",
                "rank_test_precision",
                "params",
            ]
        ]

        # Select the most performant models in terms of recall
        # (within 1 sigma from the best)
        best_recall_std = high_precision_cv_results["mean_test_recall"].std()
        best_recall = high_precision_cv_results["mean_test_recall"].max()
        best_recall_threshold = best_recall - best_recall_std

        high_recall_cv_results = high_precision_cv_results[
            high_precision_cv_results["mean_test_recall"] > best_recall_threshold
        ]
        print(
            "Out of the previously selected high precision models, we keep all the\n"
            "the models within one standard deviation of the highest recall model:"
        )
        print_dataframe(high_recall_cv_results)

        # From the best candidates, select the fastest model to predict
        fastest_top_recall_high_precision_index = high_recall_cv_results[
            "mean_score_time"
        ].idxmin()

        print(
            "\nThe selected final model is the fastest to predict out of the previously\n"
            "selected subset of best models based on precision and recall.\n"
            "Its scoring time is:\n\n"
            f"{high_recall_cv_results.loc[fastest_top_recall_high_precision_index]}"
        )

        return fastest_top_recall_high_precision_index
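
To see the selection logic in isolation, a condensed version of the same three
steps can be run on a hand-written results dictionary. This is only a sketch:
`toy_cv_results` and `condensed_refit_strategy` are hypothetical names, and the
scores and timings are invented.

```python
import pandas as pd

# Hypothetical results for four candidates (values invented for illustration).
toy_cv_results = {
    "mean_test_precision": [0.99, 1.00, 0.95, 1.00],
    "mean_test_recall": [0.60, 0.88, 0.95, 0.87],
    "mean_score_time": [0.004, 0.005, 0.002, 0.003],
}


def condensed_refit_strategy(cv_results, precision_threshold=0.98):
    """Return the index of the selected candidate."""
    results = pd.DataFrame(cv_results)
    # Step 1: keep candidates above the precision threshold.
    high_precision = results[results["mean_test_precision"] > precision_threshold]
    # Step 2: keep candidates within one standard deviation of the best recall.
    recall = high_precision["mean_test_recall"]
    high_recall = high_precision[recall > recall.max() - recall.std()]
    # Step 3: among those, pick the candidate that is fastest to score.
    return int(high_recall["mean_score_time"].idxmin())


best_index = condensed_refit_strategy(toy_cv_results)
print(best_index)  # prints 3
```

Candidate 2 has the best recall but fails the precision filter, and candidate 0
falls outside the recall band, so the choice is between candidates 1 and 3,
and the faster one wins. As in the function above, the standard deviation is
taken across the candidates' mean recalls.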









.. GENERATED FROM PYTHON SOURCE LINES 168-173

Tuning hyper-parameters
-----------------------

Once we have defined our strategy to select the best model, we define the
values of the hyper-parameters and create the grid-search instance:

.. GENERATED FROM PYTHON SOURCE LINES 174-187

.. code-block:: default

    from sklearn.model_selection import GridSearchCV
    from sklearn.svm import SVC

    tuned_parameters = [
        {"kernel": ["rbf"], "gamma": [1e-3, 1e-4], "C": [1, 10, 100, 1000]},
        {"kernel": ["linear"], "C": [1, 10, 100, 1000]},
    ]

    grid_search = GridSearchCV(
        SVC(), tuned_parameters, scoring=scores, refit=refit_strategy
    )
    grid_search.fit(X_train, y_train)





.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    All grid-search results:
    precision: 1.000 (±0.000), recall: 0.854 (±0.063), for {'C': 1, 'gamma': 0.001, 'kernel': 'rbf'}
    precision: 1.000 (±0.000), recall: 0.257 (±0.061), for {'C': 1, 'gamma': 0.0001, 'kernel': 'rbf'}
    precision: 1.000 (±0.000), recall: 0.877 (±0.069), for {'C': 10, 'gamma': 0.001, 'kernel': 'rbf'}
    precision: 0.968 (±0.039), recall: 0.780 (±0.083), for {'C': 10, 'gamma': 0.0001, 'kernel': 'rbf'}
    precision: 1.000 (±0.000), recall: 0.877 (±0.069), for {'C': 100, 'gamma': 0.001, 'kernel': 'rbf'}
    precision: 0.905 (±0.058), recall: 0.889 (±0.074), for {'C': 100, 'gamma': 0.0001, 'kernel': 'rbf'}
    precision: 1.000 (±0.000), recall: 0.877 (±0.069), for {'C': 1000, 'gamma': 0.001, 'kernel': 'rbf'}
    precision: 0.904 (±0.058), recall: 0.890 (±0.073), for {'C': 1000, 'gamma': 0.0001, 'kernel': 'rbf'}
    precision: 0.695 (±0.073), recall: 0.743 (±0.065), for {'C': 1, 'kernel': 'linear'}
    precision: 0.643 (±0.066), recall: 0.757 (±0.066), for {'C': 10, 'kernel': 'linear'}
    precision: 0.611 (±0.028), recall: 0.744 (±0.044), for {'C': 100, 'kernel': 'linear'}
    precision: 0.618 (±0.039), recall: 0.744 (±0.044), for {'C': 1000, 'kernel': 'linear'}

    Models with a precision higher than 0.98:
    precision: 1.000 (±0.000), recall: 0.854 (±0.063), for {'C': 1, 'gamma': 0.001, 'kernel': 'rbf'}
    precision: 1.000 (±0.000), recall: 0.257 (±0.061), for {'C': 1, 'gamma': 0.0001, 'kernel': 'rbf'}
    precision: 1.000 (±0.000), recall: 0.877 (±0.069), for {'C': 10, 'gamma': 0.001, 'kernel': 'rbf'}
    precision: 1.000 (±0.000), recall: 0.877 (±0.069), for {'C': 100, 'gamma': 0.001, 'kernel': 'rbf'}
    precision: 1.000 (±0.000), recall: 0.877 (±0.069), for {'C': 1000, 'gamma': 0.001, 'kernel': 'rbf'}

    Out of the previously selected high precision models, we keep all the
    the models within one standard deviation of the highest recall model:
    precision: 1.000 (±0.000), recall: 0.854 (±0.063), for {'C': 1, 'gamma': 0.001, 'kernel': 'rbf'}
    precision: 1.000 (±0.000), recall: 0.877 (±0.069), for {'C': 10, 'gamma': 0.001, 'kernel': 'rbf'}
    precision: 1.000 (±0.000), recall: 0.877 (±0.069), for {'C': 100, 'gamma': 0.001, 'kernel': 'rbf'}
    precision: 1.000 (±0.000), recall: 0.877 (±0.069), for {'C': 1000, 'gamma': 0.001, 'kernel': 'rbf'}


    The selected final model is the fastest to predict out of the previously
    selected subset of best models based on precision and recall.
    Its scoring time is:

    mean_score_time                                          0.003467
    mean_test_recall                                         0.877206
    std_test_recall                                          0.069196
    mean_test_precision                                           1.0
    std_test_precision                                            0.0
    rank_test_recall                                                3
    rank_test_precision                                             1
    params                 {'C': 10, 'gamma': 0.001, 'kernel': 'rbf'}
    Name: 2, dtype: object
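
The twelve rows in the listing of all grid-search results correspond to the
candidates generated from the two parameter dictionaries. As a quick check,
they can be enumerated with :class:`~sklearn.model_selection.ParameterGrid`;
the names `candidates` and `n_folds` are introduced here for illustration.

```python
from sklearn.model_selection import ParameterGrid

tuned_parameters = [
    {"kernel": ["rbf"], "gamma": [1e-3, 1e-4], "C": [1, 10, 100, 1000]},
    {"kernel": ["linear"], "C": [1, 10, 100, 1000]},
]

# 4 values of C x 2 values of gamma for the RBF kernel, plus 4 values of C
# for the linear kernel.
candidates = list(ParameterGrid(tuned_parameters))
print(len(candidates))  # prints 12

# With the default 5-fold cross-validation, each candidate is fitted once
# per fold before the final refit.
n_folds = 5
print(len(candidates) * n_folds)  # prints 60
```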


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    GridSearchCV(estimator=SVC(),
                 param_grid=[{'C': [1, 10, 100, 1000], 'gamma': [0.001, 0.0001],
                              'kernel': ['rbf']},
                             {'C': [1, 10, 100, 1000], 'kernel': ['linear']}],
                 refit=<function refit_strategy at 0x7f6bec89f1a0>,
                 scoring=['precision', 'recall'])

.. GENERATED FROM PYTHON SOURCE LINES 188-189

The parameters selected by the grid-search with our custom strategy are:

.. GENERATED FROM PYTHON SOURCE LINES 190-192

.. code-block:: default

    grid_search.best_params_





.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none


    {'C': 10, 'gamma': 0.001, 'kernel': 'rbf'}
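
When `refit` is a callable, the integer it returns is stored as `best_index_`,
and `best_params_` is the matching entry of `cv_results_["params"]`. The sketch
below illustrates this on a small synthetic problem; the dataset, the grid and
the `pick_best_recall` helper are assumptions made for illustration.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC

# Small synthetic binary problem so that the search runs quickly.
X_demo, y_demo = make_classification(n_samples=120, n_features=5, random_state=0)


def pick_best_recall(cv_results):
    """Hypothetical refit callable: return the index with the highest recall."""
    return int(np.argmax(cv_results["mean_test_recall"]))


search = GridSearchCV(
    SVC(),
    {"C": [0.1, 1, 10]},
    scoring=["precision", "recall"],
    refit=pick_best_recall,
)
search.fit(X_demo, y_demo)

print(search.best_index_, search.best_params_)
# `best_score_` is not set when `refit` is a callable.
print(hasattr(search, "best_score_"))
```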



.. GENERATED FROM PYTHON SOURCE LINES 193-199

Finally, we evaluate the fine-tuned model on the left-out evaluation set: the
`grid_search` object **has automatically been refit** on the full training
set with the parameters selected by our custom refit strategy.

We can use the classification report to compute standard classification
metrics on the left-out set:

.. GENERATED FROM PYTHON SOURCE LINES 200-205

.. code-block:: default

    from sklearn.metrics import classification_report

    y_pred = grid_search.predict(X_test)
    print(classification_report(y_test, y_pred))





.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

                  precision    recall  f1-score   support

           False       0.99      1.00      0.99       807
            True       1.00      0.87      0.93        92

        accuracy                           0.99       899
       macro avg       0.99      0.93      0.96       899
    weighted avg       0.99      0.99      0.99       899
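
If the metrics are needed programmatically rather than as printed text,
`classification_report` can return a dictionary through `output_dict=True`.
The following sketch uses hand-written boolean labels that mimic the pattern
above (perfect precision, lower recall); the label values are invented.

```python
import numpy as np
from sklearn.metrics import classification_report

# Hypothetical labels: 5 positives, of which 4 are detected, and no
# false alarms among the 15 negatives.
y_true_demo = np.array([True] * 5 + [False] * 15)
y_pred_demo = np.array([True] * 4 + [False] * 16)

report = classification_report(y_true_demo, y_pred_demo, output_dict=True)

# Keys are the string form of each class label.
positive_metrics = report["True"]
print(positive_metrics["precision"], positive_metrics["recall"])
```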





.. GENERATED FROM PYTHON SOURCE LINES 206-209

.. note::
   The problem is too easy: the hyperparameter plateau is too flat, so several
   candidates tie in quality and the selected model is the same whether we
   optimize for precision or for recall.


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes  9.453 seconds)


.. _sphx_glr_download_auto_examples_model_selection_plot_grid_search_digits.py:


.. only :: html

 .. container:: sphx-glr-footer
    :class: sphx-glr-footer-example



  .. container:: sphx-glr-download sphx-glr-download-python

     :download:`Download Python source code: plot_grid_search_digits.py <plot_grid_search_digits.py>`



  .. container:: sphx-glr-download sphx-glr-download-jupyter

     :download:`Download Jupyter notebook: plot_grid_search_digits.ipynb <plot_grid_search_digits.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
