# Source code for lightgbm.plotting

# coding: utf-8
# pylint: disable = C0103
"""Plotting Library."""
from __future__ import absolute_import

import warnings
from copy import deepcopy
from io import BytesIO

import numpy as np

from .basic import Booster
from .sklearn import LGBMModel


def check_not_tuple_of_2_elements(obj, obj_name='obj'):
    """Raise TypeError unless ``obj`` is a tuple holding exactly two items.

    Parameters
    ----------
    obj : object
        Value to validate.
    obj_name : str
        Name used for ``obj`` in the error message.

    Raises
    ------
    TypeError
        If ``obj`` is not a tuple, or is a tuple whose length is not 2.
    """
    # The length is only inspected once we know obj is a tuple.
    is_pair = isinstance(obj, tuple) and len(obj) == 2
    if is_pair:
        return
    raise TypeError('%s must be a tuple of 2 elements.' % obj_name)


def plot_importance(booster, ax=None, height=0.2,
                    xlim=None, ylim=None, title='Feature importance',
                    xlabel='Feature importance', ylabel='Features',
                    importance_type='split', max_num_features=None,
                    ignore_zero=True, figsize=None, grid=True, **kwargs):
    """Plot model feature importances as a horizontal bar chart.

    Parameters
    ----------
    booster : Booster or LGBMModel
        Booster or LGBMModel instance.
    ax : matplotlib Axes
        Target axes instance. If None, new figure and axes will be created.
    height : float
        Bar height, passed to ax.barh().
    xlim : tuple of 2 elements
        Tuple passed to ax.set_xlim(). If None, (0, 1.1 * largest value) is used.
    ylim : tuple of 2 elements
        Tuple passed to ax.set_ylim(). If None, (-1, number of bars) is used.
    title : str
        Axes title. Pass None to disable.
    xlabel : str
        X axis title label. Pass None to disable.
    ylabel : str
        Y axis title label. Pass None to disable.
    importance_type : str
        How the importance is calculated: "split" or "gain".
        "split" is the number of times a feature is used in a model.
        "gain" is the total gain of splits which use the feature.
    max_num_features : int
        Max number of top features displayed on plot.
        If None or smaller than 1, all features will be displayed.
    ignore_zero : bool
        Ignore features with zero importance.
    figsize : tuple of 2 elements
        Figure size. Only used when ``ax`` is None.
    grid : bool
        Whether to add a grid to the axes.
    **kwargs :
        Other keywords passed to ax.barh().

    Returns
    -------
    ax : matplotlib Axes

    Raises
    ------
    ImportError
        If matplotlib is not installed.
    TypeError
        If ``booster`` is neither a Booster nor an LGBMModel, or if
        ``figsize``/``xlim``/``ylim`` is given but is not a 2-tuple.
    ValueError
        If the booster reports no feature importances, or if no feature
        is left to plot after filtering (e.g. every importance is zero
        while ``ignore_zero`` is True).
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        raise ImportError('You must install matplotlib to plot importance.')

    if isinstance(booster, LGBMModel):
        booster = booster.booster_
    elif not isinstance(booster, Booster):
        raise TypeError('booster must be Booster or LGBMModel.')

    importance = booster.feature_importance(importance_type=importance_type)
    feature_name = booster.feature_name()

    # Explicit len() keeps this check valid if the result is an array
    # rather than a plain list.
    if not len(importance):
        raise ValueError('Booster feature_importances are empty.')

    # Ascending order, so the most important feature ends up as the top bar.
    tuples = sorted(zip(feature_name, importance), key=lambda x: x[1])
    if ignore_zero:
        tuples = [x for x in tuples if x[1] > 0]
    if max_num_features is not None and max_num_features > 0:
        tuples = tuples[-max_num_features:]

    # Without this guard, zip(*[]) below fails with an obscure
    # "not enough values to unpack" ValueError.
    if not tuples:
        raise ValueError('No features left to plot after filtering; '
                         'if all importances are zero, pass ignore_zero=False.')
    labels, values = zip(*tuples)

    if ax is None:
        if figsize is not None:
            check_not_tuple_of_2_elements(figsize, 'figsize')
        _, ax = plt.subplots(1, 1, figsize=figsize)

    ylocs = np.arange(len(values))
    ax.barh(ylocs, values, align='center', height=height, **kwargs)

    # Annotate every bar with its value, one data unit right of the bar end.
    for x, y in zip(values, ylocs):
        ax.text(x + 1, y, x, va='center')

    ax.set_yticks(ylocs)
    ax.set_yticklabels(labels)

    if xlim is not None:
        check_not_tuple_of_2_elements(xlim, 'xlim')
    else:
        xlim = (0, max(values) * 1.1)
    ax.set_xlim(xlim)

    if ylim is not None:
        check_not_tuple_of_2_elements(ylim, 'ylim')
    else:
        ylim = (-1, len(values))
    ax.set_ylim(ylim)

    if title is not None:
        ax.set_title(title)
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    ax.grid(grid)
    return ax
def plot_metric(booster, metric=None, dataset_names=None,
                ax=None, xlim=None, ylim=None,
                title='Metric during training',
                xlabel='Iterations', ylabel='auto',
                figsize=None, grid=True):
    """Plot the per-iteration values of one metric for one or more datasets.

    Parameters
    ----------
    booster : dict or LGBMModel
        Evals_result recorded by lightgbm.train() or LGBMModel instance.
        The input is deep-copied, so the caller's data is not modified.
    metric : str or None
        The metric name to plot. Only one metric is plotted at a time.
        Pass None to let one be picked from the first dataset
        (a warning is emitted if that dataset holds several metrics).
    dataset_names : None or list of str
        Names of the datasets to plot. Must be a non-empty list, tuple or set.
        Pass None to plot all datasets.
    ax : matplotlib Axes
        Target axes instance. If None, new figure and axes will be created.
    xlim : tuple of 2 elements
        Tuple passed to ax.set_xlim(). If None, (0, number of iterations).
    ylim : tuple of 2 elements
        Tuple passed to ax.set_ylim(). If None, the plotted value range
        widened by 20% of its span on each side.
    title : str
        Axes title. Pass None to disable.
    xlabel : str
        X axis title label. Pass None to disable.
    ylabel : str
        Y axis title label. Pass None to disable. Pass 'auto' to use `metric`.
    figsize : tuple of 2 elements
        Figure size. Only used when ``ax`` is None.
    grid : bool
        Whether to add a grid to the axes.

    Returns
    -------
    ax : matplotlib Axes
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        raise ImportError('You must install matplotlib to plot metric.')

    if isinstance(booster, LGBMModel):
        raw_results = booster.evals_result_
    elif isinstance(booster, dict):
        raw_results = booster
    else:
        raise TypeError('booster must be dict or LGBMModel.')
    # Work on a private copy: popitem() below mutates it.
    eval_results = deepcopy(raw_results)

    if len(eval_results) == 0:
        raise ValueError('eval results cannot be empty.')

    if ax is None:
        if figsize is not None:
            check_not_tuple_of_2_elements(figsize, 'figsize')
        _, ax = plt.subplots(1, 1, figsize=figsize)

    if dataset_names is None:
        pending_names = iter(eval_results.keys())
    elif isinstance(dataset_names, (list, tuple, set)) and dataset_names:
        pending_names = iter(dataset_names)
    else:
        raise ValueError('dataset_names should be iterable and cannot be empty')

    # The first dataset decides which metric is plotted and the x range.
    first_name = next(pending_names)
    first_metrics = eval_results[first_name]
    if metric is None:
        if len(first_metrics) > 1:
            warnings.warn("more than one metric available, picking one to plot.",
                          stacklevel=2)
        metric, results = first_metrics.popitem()
    elif metric in first_metrics:
        results = first_metrics[metric]
    else:
        raise KeyError('No given metric in eval results.')

    num_iteration = len(results)
    max_result = max(results)
    min_result = min(results)
    x_ = range(num_iteration)
    ax.plot(x_, results, label=first_name)

    # Remaining datasets share the axes; track the overall value range.
    for other_name in pending_names:
        results = eval_results[other_name][metric]
        max_result = max(max(results), max_result)
        min_result = min(min(results), min_result)
        ax.plot(x_, results, label=other_name)

    ax.legend(loc='best')

    if xlim is None:
        xlim = (0, num_iteration)
    else:
        check_not_tuple_of_2_elements(xlim, 'xlim')
    ax.set_xlim(xlim)

    if ylim is None:
        margin = (max_result - min_result) * 0.2
        ylim = (min_result - margin, max_result + margin)
    else:
        check_not_tuple_of_2_elements(ylim, 'ylim')
    ax.set_ylim(ylim)

    if ylabel == 'auto':
        ylabel = metric

    if title is not None:
        ax.set_title(title)
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    ax.grid(grid)
    return ax
def _to_graphviz(tree_info, show_info, feature_names, name=None, comment=None, filename=None, directory=None, format=None, engine=None, encoding=None, graph_attr=None, node_attr=None, edge_attr=None, body=None, strict=False): """Convert specified tree to graphviz instance. See: - http://graphviz.readthedocs.io/en/stable/api.html#digraph """ try: from graphviz import Digraph except ImportError: raise ImportError('You must install graphviz to plot tree.') def add(root, parent=None, decision=None): """recursively add node or edge""" if 'split_index' in root: # non-leaf name = 'split' + str(root['split_index']) if feature_names is not None: label = 'split_feature_name:' + str(feature_names[root['split_feature']]) else: label = 'split_feature_index:' + str(root['split_feature']) label += '\nthreshold:' + str(root['threshold']) for info in show_info: if info in {'split_gain', 'internal_value', 'internal_count'}: label += '\n' + info + ':' + str(root[info]) graph.node(name, label=label) if root['decision_type'] == 'no_greater': l_dec, r_dec = '<=', '>' elif root['decision_type'] == 'is': l_dec, r_dec = 'is', "isn't" else: raise ValueError('Invalid decision type in tree model.') add(root['left_child'], name, l_dec) add(root['right_child'], name, r_dec) else: # leaf name = 'leaf' + str(root['leaf_index']) label = 'leaf_index:' + str(root['leaf_index']) label += '\nleaf_value:' + str(root['leaf_value']) if 'leaf_count' in show_info: label += '\nleaf_count:' + str(root['leaf_count']) graph.node(name, label=label) if parent is not None: graph.edge(parent, name, decision) graph = Digraph(name=name, comment=comment, filename=filename, directory=directory, format=format, engine=engine, encoding=encoding, graph_attr=graph_attr, node_attr=node_attr, edge_attr=edge_attr, body=body, strict=strict) add(tree_info['tree_structure']) return graph
def create_tree_digraph(booster, tree_index=0, show_info=None, name=None,
                        comment=None, filename=None, directory=None,
                        format=None, engine=None, encoding=None,
                        graph_attr=None, node_attr=None, edge_attr=None,
                        body=None, strict=False):
    """Create a digraph of the specified tree.

    See:
      - http://graphviz.readthedocs.io/en/stable/api.html#digraph

    Parameters
    ----------
    booster : Booster, LGBMModel
        Booster or LGBMModel instance.
    tree_index : int, default 0
        Index of the target tree within the dumped model.
    show_info : list
        Information shown on nodes.
        options: 'split_gain', 'internal_value', 'internal_count' or 'leaf_count'.
        None is treated as an empty list.
    name : str
        Graph name used in the source code.
    comment : str
        Comment added to the first line of the source.
    filename : str
        Filename for saving the source (defaults to name + '.gv').
    directory : str
        (Sub)directory for source saving and rendering.
    format : str
        Rendering output format ('pdf', 'png', ...).
    engine : str
        Layout command used ('dot', 'neato', ...).
    encoding : str
        Encoding for saving the source.
    graph_attr : dict
        Mapping of (attribute, value) pairs for the graph.
    node_attr : dict
        Mapping of (attribute, value) pairs set for all nodes.
    edge_attr : dict
        Mapping of (attribute, value) pairs set for all edges.
    body : list of str
        Iterable of lines to add to the graph body.
    strict : bool
        Forwarded unchanged as the ``strict`` argument of graphviz Digraph.

    Returns
    -------
    graph : graphviz Digraph

    Raises
    ------
    TypeError
        If ``booster`` is neither a Booster nor an LGBMModel.
    IndexError
        If ``tree_index`` is not smaller than the number of trees.
    """
    if isinstance(booster, LGBMModel):
        booster = booster.booster_
    elif not isinstance(booster, Booster):
        raise TypeError('booster must be Booster or LGBMModel.')

    model = booster.dump_model()
    tree_infos = model['tree_info']
    # Feature names are optional in the dumped model.
    feature_names = model['feature_names'] if 'feature_names' in model else None

    if not tree_index < len(tree_infos):
        raise IndexError('tree_index is out of range.')
    tree_info = tree_infos[tree_index]

    if show_info is None:
        show_info = []

    return _to_graphviz(tree_info, show_info, feature_names,
                        name=name, comment=comment, filename=filename,
                        directory=directory, format=format, engine=engine,
                        encoding=encoding, graph_attr=graph_attr,
                        node_attr=node_attr, edge_attr=edge_attr,
                        body=body, strict=strict)
def plot_tree(booster, ax=None, tree_index=0, figsize=None,
              graph_attr=None, node_attr=None, edge_attr=None,
              show_info=None):
    """Render the specified tree as an image on matplotlib axes.

    The tree is converted to a graphviz digraph, rendered to PNG bytes in
    memory, and shown with ``ax.imshow`` on axes whose frame is turned off.

    Parameters
    ----------
    booster : Booster, LGBMModel
        Booster or LGBMModel instance.
    ax : matplotlib Axes
        Target axes instance. If None, new figure and axes will be created.
    tree_index : int, default 0
        Index of the target tree.
    figsize : tuple of 2 elements
        Figure size. Only used when ``ax`` is None.
    graph_attr : dict
        Mapping of (attribute, value) pairs for the graph.
    node_attr : dict
        Mapping of (attribute, value) pairs set for all nodes.
    edge_attr : dict
        Mapping of (attribute, value) pairs set for all edges.
    show_info : list
        Information shown on nodes.
        options: 'split_gain', 'internal_value', 'internal_count' or 'leaf_count'.

    Returns
    -------
    ax : matplotlib Axes
    """
    try:
        import matplotlib.pyplot as plt
        import matplotlib.image as mpl_image
    except ImportError:
        raise ImportError('You must install matplotlib to plot tree.')

    if ax is None:
        if figsize is not None:
            check_not_tuple_of_2_elements(figsize, 'figsize')
        _, ax = plt.subplots(1, 1, figsize=figsize)

    graph = create_tree_digraph(booster=booster,
                                tree_index=tree_index,
                                graph_attr=graph_attr,
                                node_attr=node_attr,
                                edge_attr=edge_attr,
                                show_info=show_info)

    # Keep the rendered PNG in memory and hand it to matplotlib as a file.
    png_stream = BytesIO(graph.pipe(format='png'))
    ax.imshow(mpl_image.imread(png_stream))
    ax.axis('off')
    return ax