niml.explainability package
Submodules
niml.explainability.explainability module
- class niml.explainability.explainability.Explainability(enc_obj=None, model_obj=None, input_data=None, label_col=0, pred_data=None, synapse_threshold=0.6, bin_threshold=0.9, prediction_idx=None, output_file=None)
Explainability object. Given trained encoder and model objects and a prediction index, it decomposes the predicted class and displays the range of raw input values that define the label and subclass associated with that prediction index.
- Parameters
enc_obj (niml.encoder) – The instantiated and trained encoder class object used to make the prediction being explained
model_obj (niml.model) – The instantiated and trained model class object used to make the prediction being explained
input_data (file-like object or list) – Only required if enc_obj is loaded from state. The original raw text data submitted to the niml.encoder encode method that generated the prediction being explained
label_col (int, optional) – Only required if enc_obj is loaded from state. If the input_data contains a column with labels, give its column number here; defaults to 0. If there is no label column, set this parameter to None
pred_data (file-like object or list) – Only required if model_obj is loaded from state. Output data returned from the model.get_subclass_predictions method, containing a label and subclass, either saved to a file as a comma-separated two-column data set or stored in a Python list
synapse_threshold (float) – A value between 0 and 1 that defines how closely a signature must match the pooled definitions to be considered a match
bin_threshold (float) – A value between 0 and 1 that defines how closely a decoded set of pooled definitions must match a set of encoded bins to be considered a match
prediction_idx (int) – The index of a single prediction in the list of predictions. It selects both the predicted classification to explain and the raw input data to plot alongside the decomposed explanation
output_file (basestring, optional) – A string for a file location to write the generated explainability plot. If not given, the plot will be displayed to standard out
- Raises
AttributeError – Raised if any of the following are not defined: enc_obj, model_obj, prediction_idx
ValueError – Raised if no input_data is received and the enc_obj provided does not contain raw column data
IndexError – Raised if the prediction_idx value is not present in the list of predicted values passed in for pred_data or contained in the list from the model_obj parameter
KeyError – Raised if either of the following lookups is unsuccessful: the prediction for the prediction_idx provided, or the signature for the label and subclass decoded from the prediction at that prediction_idx
- explain()
Primary method for performing explainability. Uses the inputs provided at initialization to decompose the prediction and determine the range of raw values associated with a label/subclass prediction
- Raises
AttributeError – Raised when there are no synapses that appear to match the predicted classification with enough certainty to pass the synapse_threshold parameter
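The construct-then-explain flow described above can be sketched as follows. This is only an illustration under stated assumptions: `StubExplainability` is a hypothetical stand-in written for this example so the snippet runs without niml installed, and `explain_prediction` is a helper name invented here. With niml available, the documented class at `niml.explainability.explainability.Explainability` would presumably be passed in instead of the stub.

```python
class StubExplainability:
    """Hypothetical stand-in that mimics the documented interface; not niml code."""

    def __init__(self, enc_obj=None, model_obj=None, prediction_idx=None,
                 synapse_threshold=0.6, **kwargs):
        # Mirrors the documented AttributeError for missing required arguments.
        if enc_obj is None or model_obj is None or prediction_idx is None:
            raise AttributeError("enc_obj, model_obj and prediction_idx are required")
        self.synapse_threshold = synapse_threshold
        self._ranges = None

    def explain(self):
        # Assumption for the demo: a very strict threshold matches nothing.
        if self.synapse_threshold > 0.95:
            raise AttributeError("no synapses passed synapse_threshold")
        self._ranges = {"raw": {0: [(5.4, 6.5)]}, "pct": {0: [(0.27, 0.63)]}}

    def get_range_data_values(self):
        return self._ranges


def explain_prediction(explainability_cls, enc_obj, model_obj, prediction_idx,
                       synapse_threshold=0.6, bin_threshold=0.9):
    """Build the explainer, run explain(), and return the range dictionary.

    Returns None when explain() raises AttributeError, which the documentation
    says happens when no synapse matches with enough certainty.
    """
    explainer = explainability_cls(
        enc_obj=enc_obj,
        model_obj=model_obj,
        prediction_idx=prediction_idx,
        synapse_threshold=synapse_threshold,
        bin_threshold=bin_threshold,
    )
    try:
        explainer.explain()
    except AttributeError:
        return None
    return explainer.get_range_data_values()


# Placeholder objects stand in for a trained encoder and model.
ranges = explain_prediction(StubExplainability, object(), object(), prediction_idx=3)
strict = explain_prediction(StubExplainability, object(), object(), prediction_idx=3,
                            synapse_threshold=0.99)
```

Catching `AttributeError` this broadly can also hide unrelated attribute mistakes, so in real code it may be safer to log the exception message rather than discard it.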
- get_range_data_values()
After running explain on a signature, a call to this method will return a dictionary containing the raw numeric and percentage values for the feature ranges
- Returns
range_data_values – A dictionary of dictionaries with two fixed top-level keys (raw, pct) and one sub-level key per feature. Each feature key maps to a list of (min, max) tuples holding the raw or percentage values. An example is shown below:
{
    'raw': {
        0: [
            (5.40597016, 6.5029851),
            (6.54776122, 6.7268657),
            (7.10746272, 7.10746272),
            (7.1746269, 7.26417914),
            (7.33134332, 7.3985075)
        ],
        1: [(2.8044778, 3.0731346)],
        2: [(3.6597014, 4.9917909)],
        3: [(1.14701498, 1.54179112)]
    },
    'pct': {
        0: [
            (0.26865672000000007, 0.6343283666666668),
            (0.6492537400000001, 0.7089552333333335),
            (0.8358209066666668, 0.8358209066666668),
            (0.8582089666666667, 0.8880597133333336),
            (0.9104477733333335, 0.9328358333333334)
        ],
        1: [(-0.9975123333333333, -0.8482585555555555)],
        2: [(-0.18437227450980387, 0.07682174509803924)],
        3: [(-1.5012978347826087, -1.3296560347826085)]
    }
}
- Return type
dict
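As a minimal sketch of how the returned structure might be consumed, the helpers below test whether a value falls inside any range for a feature and report the overall bounds across that feature's ranges. The helper names `value_in_ranges` and `feature_bounds` are hypothetical and not part of niml; the sample dictionary is a shortened copy of the example above.

```python
def value_in_ranges(range_data_values, feature, value, kind="raw"):
    """Return True if value lies inside any (min, max) range for the feature."""
    return any(low <= value <= high
               for low, high in range_data_values[kind][feature])


def feature_bounds(range_data_values, feature, kind="raw"):
    """Return the smallest min and largest max across a feature's ranges."""
    ranges = range_data_values[kind][feature]
    return (min(low for low, _ in ranges), max(high for _, high in ranges))


# Shortened copy of the documented example output.
sample = {
    "raw": {
        0: [(5.40597016, 6.5029851), (6.54776122, 6.7268657)],
        1: [(2.8044778, 3.0731346)],
    },
    "pct": {
        0: [(0.26865672000000007, 0.6343283666666668)],
        1: [(-0.9975123333333333, -0.8482585555555555)],
    },
}

inside = value_in_ranges(sample, 0, 6.0)    # falls in the first raw range
in_gap = value_in_ranges(sample, 0, 6.52)   # falls between the two raw ranges
bounds = feature_bounds(sample, 0)
```

Because a feature can have several disjoint ranges, checking only the overall bounds would wrongly accept values that fall in a gap, which is why the membership test walks every tuple.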
- load_input_data(input_data=None, label_col=0)
Convenience method to allow users to load input data without instantiating a completely new explainability object.
- Parameters
input_data (file-like object or list) – The original raw text data submitted to the niml.encoder encode method that generated the prediction being explained
label_col (int, optional) – If the input_data contains a column with labels, give its column number here; defaults to 0. If there is no label column, set this parameter to None
- Raises
ValueError – Raised if no input_data is received or the input data is not in a format expected by the encoder class
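To illustrate what the label_col parameter means, the sketch below separates a label column from the feature columns of row-oriented data. It does not show how niml parses its input: `split_labels` is a hypothetical helper, and the assumption that each row is already a list of string cells is made only for this example.

```python
def split_labels(rows, label_col=0):
    """Separate the label column from the feature columns.

    With label_col=None, every column is treated as a feature and the
    returned labels value is None.
    """
    labels = [] if label_col is not None else None
    features = []
    for row in rows:
        cells = list(row)  # copy so the caller's rows are not modified
        if label_col is not None:
            labels.append(cells.pop(label_col))
        features.append(cells)
    return labels, features


# Hypothetical rows: a label in column 0 followed by two feature values.
rows = [["a", "5.1", "3.5"], ["b", "6.2", "2.9"]]

labels, features = split_labels(rows, label_col=0)
no_labels, all_features = split_labels(rows, label_col=None)
```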
- load_predictions(pred_data=None)
Convenience method to allow users to load prediction data without instantiating a completely new explainability object.
- Parameters
pred_data (file-like object or list) – Output data returned from the model.get_subclass_predictions method, containing a label and subclass, either saved to a file as a comma-separated two-column data set or stored in a Python list
- Raises
ValueError – Raised if no pred_data is received or the data received is not in the format expected of a file-like object or Python list
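The comma-separated two-column layout described for pred_data can be sketched as below. This assumes predictions are available as (label, subclass) pairs, which the documentation does not spell out; `predictions_to_csv` and the sample labels are hypothetical. The result is an in-memory file-like object that could be passed wherever a file-like pred_data is accepted.

```python
import csv
import io


def predictions_to_csv(pairs):
    """Write (label, subclass) pairs as two comma-separated columns."""
    buffer = io.StringIO()
    writer = csv.writer(buffer, lineterminator="\n")
    for label, subclass in pairs:
        writer.writerow([label, subclass])
    buffer.seek(0)  # rewind so a consumer reads from the first row
    return buffer


# Hypothetical predictions: one (label, subclass) pair per row.
pred_file = predictions_to_csv([("setosa", 0), ("versicolor", 2)])
first_row = pred_file.readline().rstrip("\n")
```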