Add sensitivity, specificity ... etc #36

Open
wants to merge 1 commit into base: master
44 changes: 33 additions & 11 deletions docs/source/perfeval.rst
@@ -35,9 +35,9 @@ Classification Performance
0 2 1
0 0 2

julia> C ./ sum(C, 2) # normalize per class
3x3 Array{Float64,2}:
0.666667 0.333333 0.0
0.0 0.666667 0.333333
0.0 0.0 1.0

@@ -52,7 +52,7 @@ Hit rate (for retrieval tasks)

.. function:: hitrate(gt, ranklist, k)

Compute the hitrate of rank ``k`` for a ranked list of predictions given by ``ranklist`` w.r.t. the ground truths given in ``gt``.

Particularly, if ``gt[i]`` is contained in ``ranklist[1:k, i]``, then the prediction for the ``i``-th sample is said to be *hit within rank ``k``*. The hitrate of rank ``k`` is the fraction of predictions that hit within rank ``k``.
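
A minimal sketch with hypothetical ``gt`` and ``ranklist`` values (ranked predictions stored column-wise, one column per sample)::

julia> gt = [1, 2, 3]
julia> ranklist = [1 3 2; 2 2 1; 3 1 3]  # column i is the ranked prediction list for sample i
julia> hitrate(gt, ranklist, 2)          # gt[i] is in the top 2 for samples 1 and 2 only, so 2/3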

@@ -111,7 +111,7 @@ One can compute a variety of performance measurements from an instance of ``ROCNums``
the fraction of negative samples correctly predicted as negative, defined as ``r.tn / r.n``

.. function:: false_positive_rate(r)

the fraction of negative samples incorrectly predicted as positive, defined as ``r.fp / r.n``

.. function:: false_negative_rate(r)
@@ -122,10 +122,34 @@ One can compute a variety of performance measurements from an instance of ``ROCNums``

Equivalent to ``true_positive_rate(r)``.

.. function:: sensitivity(r)

Equivalent to ``true_positive_rate(r)``, for semantic convenience.

.. function:: specificity(r)

Equivalent to ``true_negative_rate(r)``, for semantic convenience.

.. function:: precision(r)

the fraction of positive predictions that are correct, defined as ``r.tp / (r.tp + r.fp)``.

.. function:: positive_predictive_value(r)

Equivalent to ``precision(r)``.

.. function:: negative_predictive_value(r)

the fraction of negative predictions that are correct, defined as ``r.tn / (r.tn + r.fn)``.

.. function:: false_discovery_rate(r)

the fraction of positive predictions that are incorrect, defined as ``r.fp / (r.tp + r.fp)``.

.. function:: accuracy(r)

the fraction of all predictions that are correct, defined as ``(r.tp + r.tn) / (r.p + r.n)``.

.. function:: f1score(r)

the harmonic mean of ``recall(r)`` and ``precision(r)``.
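
As a concrete sketch, these measurements can be read off an ``ROCNums`` instance constructed directly; the counts below are hypothetical, and the constructor is assumed to take the fields in the order ``(p, n, tp, tn, fp, fn)``::

julia> r = ROCNums{Int}(10, 20, 8, 15, 5, 2)  # p, n, tp, tn, fp, fn (assumed order)
julia> sensitivity(r)                # 8 / 10 = 0.8
julia> specificity(r)                # 15 / 20 = 0.75
julia> precision(r)                  # 8 / 13
julia> negative_predictive_value(r)  # 15 / 17
julia> false_discovery_rate(r)       # 5 / 13
julia> accuracy(r)                   # 23 / 30
julia> f1score(r)                    # 16 / 23, the harmonic mean of precision and recall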
@@ -141,7 +165,7 @@ The package provides a function ``roc`` to compute an instance of ``ROCNums`` or

.. function:: roc(gt, scores, thres[, ord])

Compute an ROC instance or an ROC curve (a vector of ``ROC`` instances), based on given scores and a threshold ``thres``.

Prediction will be made as follows:

@@ -152,15 +176,15 @@ The package provides a function ``roc`` to compute an instance of ``ROCNums`` or

**Returns:**

- When ``thres`` is a single number, it produces a single ``ROCNums`` instance;
- When ``thres`` is a vector, it produces a vector of ``ROCNums`` instances.

**Note:** Jointly evaluating an ROC curve for multiple thresholds is generally much faster than evaluating for them individually.
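
A minimal sketch with hypothetical data, using the default ``Forward`` ordering (higher scores mean more likely positive)::

julia> gt = [1, 1, 0, 0, 1]                   # binary ground truth
julia> scores = [0.9, 0.4, 0.6, 0.2, 0.7]
julia> r = roc(gt, scores, 0.5)               # one threshold -> one ROCNums instance
julia> true_positive_rate(r)                  # 2 of the 3 positives score >= 0.5, so 2/3
julia> curve = roc(gt, scores, 0.0:0.25:1.0)  # several thresholds -> vector of ROCNums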


.. function:: roc(gt, (preds, scores), thres[, ord])

Compute an ROC instance or an ROC curve (a vector of ``ROC`` instances) for multi-class classification, based on given predictions, scores and a threshold ``thres``.

Prediction is made as follows:

@@ -172,7 +196,7 @@ The package provides a function ``roc`` to compute an instance of ``ROCNums`` or
**Returns:**

- When ``thres`` is a single number, it produces a single ``ROCNums`` instance.
- When ``thres`` is a vector, it produces an ROC curve (a vector of ``ROCNums`` instances).

**Note:** Jointly evaluating an ROC curve for multiple thresholds is generally much faster than evaluating for them individually.
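
A sketch with hypothetical multi-class data (each predicted label comes with a confidence score; a prediction whose score falls below the threshold is mapped to label ``0``)::

julia> gt    = [1, 2, 2, 3]
julia> preds = [1, 2, 3, 3]                        # predicted labels
julia> ss    = [0.9, 0.6, 0.3, 0.8]                # confidence of each prediction
julia> r = roc(gt, (preds, ss), 0.5)               # one threshold -> one ROCNums instance
julia> curve = roc(gt, (preds, ss), 0.0:0.25:1.0)  # ROC curve over several thresholds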

@@ -199,5 +223,3 @@ The package provides a function ``roc`` to compute an instance of ``ROCNums`` or
.. function:: roc(gt, (preds, scores))

Equivalent to ``roc(gt, (preds, scores), 100, Forward)``.


7 changes: 6 additions & 1 deletion src/MLBase.jl
@@ -63,7 +63,13 @@ module MLBase
false_positive_rate, # rate of false positives
false_negative_rate, # rate of false negatives
recall, # recall computed from ROCNums
sensitivity, # sensitivity computed from ROCNums
specificity, # specificity computed from ROCNums
precision, # precision computed from ROCNums
positive_predictive_value, # positive predictive value computed from ROCNums
negative_predictive_value, # negative predictive value computed from ROCNums
false_discovery_rate, # false discovery rate computed from ROCNums
accuracy, # accuracy computed from ROCNums
f1score, # F1-score computed from ROCNums

# modeltune
@@ -79,4 +85,3 @@ module MLBase

include("deprecates.jl")
end

27 changes: 16 additions & 11 deletions src/perfeval.jl
@@ -67,7 +67,7 @@ function counthits(gt::IntegerVector, rklst::IntegerMatrix, ks::IntegerVector)
end


hitrate(gt::IntegerVector, rklst::IntegerMatrix, k::Integer) =
(counthits(gt, rklst, k) / length(gt))::Float64

function hitrates(gt::IntegerVector, rklst::IntegerMatrix, ks::IntegerVector)
@@ -114,8 +114,14 @@ false_positive_rate(x::ROCNums) = x.fp / x.n
false_negative_rate(x::ROCNums) = x.fn / x.p

recall(x::ROCNums) = true_positive_rate(x)
sensitivity(x::ROCNums) = true_positive_rate(x)
specificity(x::ROCNums) = true_negative_rate(x)
precision(x::ROCNums) = x.tp / (x.tp + x.fp)
positive_predictive_value(x::ROCNums) = precision(x)
negative_predictive_value(x::ROCNums) = x.tn / (x.tn + x.fn)
false_discovery_rate(x::ROCNums) = x.fp / (x.tp + x.fp)

accuracy(x::ROCNums) = (x.tp + x.tn) / (x.p + x.n)
f1score(x::ROCNums) = (tp2 = x.tp + x.tp; tp2 / (tp2 + x.fp + x.fn) )


@@ -178,7 +184,7 @@ length(v::BinaryThresPredVec) = length(v.scores)
getindex(v::BinaryThresPredVec, i::Integer) = !lt(v.ord, v.scores[i], v.thres)

# compute roc numbers based on scores & threshold
roc(gt::IntegerVector, scores::RealVector, t::Real, ord::Ordering) =
_roc(gt, BinaryThresPredVec(scores, t, ord))

roc(gt::IntegerVector, scores::RealVector, thres::Real) =
@@ -211,7 +217,7 @@ length(v::ThresPredVec) = length(v.preds)
getindex(v::ThresPredVec, i::Integer) = ifelse(lt(v.ord, v.scores[i], v.thres), 0, v.preds[i])

# compute roc numbers based on predictions & scores & threshold
roc{PV<:IntegerVector,SV<:RealVector}(gt::IntegerVector, preds::@compat(Tuple{PV,SV}), t::Real, ord::Ordering) =
_roc(gt, ThresPredVec(preds..., t, ord))

roc{PV<:IntegerVector,SV<:RealVector}(gt::IntegerVector, preds::@compat(Tuple{PV,SV}), thres::Real) =
@@ -246,10 +252,10 @@ end

find_thresbin(x::Real, thresholds::RealVector) = find_thresbin(x, thresholds, Forward)

lin_thresholds(scores::RealArray, n::Integer, ord::ForwardOrdering) =
((s0, s1) = extrema(scores); intv = (s1 - s0) / (n-1); s0:intv:s1)

lin_thresholds(scores::RealArray, n::Integer, ord::ReverseOrdering{ForwardOrdering}) =
((s0, s1) = extrema(scores); intv = (s0 - s1) / (n-1); s1:intv:s0)

# roc for binary predictions
@@ -293,7 +299,7 @@ end

roc(gt::IntegerVector, scores::RealVector, thresholds::RealVector) = roc(gt, scores, thresholds, Forward)

roc(gt::IntegerVector, scores::RealVector, n::Integer, ord::Ordering) =
roc(gt, scores, lin_thresholds(scores, n, ord), ord)

roc(gt::IntegerVector, scores::RealVector, n::Integer) = roc(gt, scores, n, Forward)
@@ -357,15 +363,14 @@ end
roc{PV<:IntegerVector, SV<:RealVector}(gt::IntegerVector, preds::@compat(Tuple{PV,SV}), thresholds::RealVector) =
roc(gt, preds, thresholds, Forward)

roc{PV<:IntegerVector, SV<:RealVector}(gt::IntegerVector, preds::@compat(Tuple{PV,SV}), n::Integer, ord::Ordering) =
roc(gt, preds, lin_thresholds(preds[2],n,ord), ord)

roc{PV<:IntegerVector, SV<:RealVector}(gt::IntegerVector, preds::@compat(Tuple{PV,SV}), n::Integer) =
roc(gt, preds, n, Forward)

roc{PV<:IntegerVector, SV<:RealVector}(gt::IntegerVector, preds::@compat(Tuple{PV,SV}), ord::Ordering) =
roc(gt, preds, 100, ord)

roc{PV<:IntegerVector, SV<:RealVector}(gt::IntegerVector, preds::@compat(Tuple{PV,SV})) =
roc(gt, preds, Forward)

11 changes: 8 additions & 3 deletions test/perfeval.jl
@@ -3,7 +3,7 @@ using MLBase
using Base.Test

import StatsBase
import StatsBase: harmmean

## correctrate & errorrate

@@ -29,7 +29,7 @@ rs = [1 2 2 1 3 2 1 1 3 3;

@test counthits(gt, rs, 1:3) == [3, 8, 8]
@test counthits(gt, rs, [2, 4]) == [8, 10]
@test counthits(gt, rs, 1:2:5) == [3, 8, 10]

@test_approx_eq [hitrate(gt, rs, k) for k=1:5] [0.3, 0.8, 0.8, 1.0, 1.0]
@test_approx_eq hitrates(gt, rs, 1:3) [0.3, 0.8, 0.8]
@@ -57,7 +57,13 @@ r = ROCNums{Int}(
@test false_negative_rate(r) == 0.20

@test recall(r) == 0.80
@test sensitivity(r) == 0.80
@test specificity(r) == 0.75
@test precision(r) == (8/13)
@test positive_predictive_value(r) == (8/13)
@test negative_predictive_value(r) == (15/17)
@test false_discovery_rate(r) == (5/13)
@test accuracy(r) == (23/30)
@test_approx_eq f1score(r) harmmean([recall(r), precision(r)])


@@ -126,4 +132,3 @@ r100 = roc(gt, (pr, ss), 1.00)
@test roc(gt, (pr, ss), 0.0:0.25:1.0) == [r00, r25, r50, r75, r100]
# @test roc(gt, (pr, ss), 7) == roc(gt, (pr, ss), 0.2:0.1:0.8, Forward)
@test roc(gt, (pr, ss)) == roc(gt, (pr, ss), MLBase.lin_thresholds([0.2, 0.8], 100, Forward))