59 lines
2.1 KiB
Plaintext
59 lines
2.1 KiB
Plaintext
module horse/prob/pmf
|
|
|
|
import std/core/list
|
|
|
|
// Discrete-support probability distribution, represented as a linked list of events
// whose support values are kept in strictly increasing order.
|
|
pub type pmf<s, v> {
  // One support point `s` carrying the value `v`, followed by the rest of the chain.
  Event(s: s, v: v, next: pmf<s, v>)
  // Terminates the chain; on its own it is the empty distribution.
  End
}
|
|
|
|
// Insert the entry `(s, v)` into `p`, keeping supports in increasing order.
// If `s` is already present, the stored value becomes `v + old` using the
// caller-supplied `(+)` (first argument is the new value, second the old one).
// `e` is the effect of that combinator.
pub fun add(p: pmf<s, v>, s: s, v: v, ?s/cmp: (a: s, b: s) -> order, ?v/(+): (new: v, old: v) -> e v): e pmf<s, v>
  match p
    Event(here, w, rest) ->
      match cmp(s, here)
        // `s` belongs further along the chain: keep this node and recurse.
        Gt -> Event(here, w, rest.add(s, v))
        // Same support: merge values, new one first.
        Eq -> Event(s, v + w, rest)
        // `s` is smaller than every remaining support: prepend it to the whole chain.
        Lt -> Event(s, v, p)
    End -> Event(s, v, End)
|
|
|
|
// Insert `(s, v)` into `p`, overwriting whatever value was previously stored
// at support `s` (if any). Implemented as `add` with a combinator that keeps
// only the new value.
pub inline fun set(p: pmf<s, v>, s: s, v: v, ?s/cmp: (a: s, b: s) -> order): e pmf<s, v>
  add(p, s, v, cmp, fn(fresh, stale) fresh)
|
|
|
|
// Build a distribution from `(support, value)` pairs by inserting them one at a
// time, left to right, with `add`. When a support occurs more than once, the
// later pair is passed to `(+)` as the "new" value and the accumulated one as
// the "old" value.
pub fun list/pmf(l: list<(s, v)>, ?s/cmp: (a: s, b: s) -> order, ?v/(+): (new: v, old: v) -> e v): e pmf<s, v>
  l.foldl(End) fn(acc, entry)
    match entry
      (s, v) -> acc.add(s, v)
|
|
|
|
// Left fold over the entries of `p` in chain (increasing-support) order.
// `f` receives the running accumulator plus each entry's support and value.
// Its result becomes the accumulator for the next entry. Returns `init` for `End`.
pub tail fun foldl(p: pmf<s, v>, init: a, f: (a, s, v) -> e a): e a
  match p
    Event(s, v, next) -> next.foldl(f(init, s, v), f)
    End -> init
|
|
|
|
// Convert the distribution to a list of `(support, value)` entries, in the
// same increasing-support order as the distribution itself.
pub fun pmf/list(p: pmf<s, v>): list<(s, v)>
  // Consing during a left fold builds the list back-to-front (largest support
  // first), so reverse it to preserve the distribution's increasing order.
  p.foldl(Nil, fn(acc, s, v) Cons((s, v), acc)).reverse
|
|
|
|
// Co-occurrence of two distributions: keeps only supports present in both `a`
// and `b`, combining their values with the caller-supplied `(*)`. Supports
// found in just one input are dropped. Both inputs are walked in one
// merge-style pass, relying on their increasing-support order.
pub fun (*)(a: pmf<s, v>, b: pmf<s, v>, ?s/cmp: (a: s, b: s) -> order, ?v/(*): (a: v, b: v) -> e v): e pmf<s, v>
  match (a, b)
    (Event(sa, va, resta), Event(sb, vb, restb)) ->
      match cmp(sa, sb)
        // `sa` cannot appear in `b`; skip it.
        Lt -> resta * b
        Eq -> Event(sa, va * vb, resta * restb)
        // `sb` cannot appear in `a`; skip it.
        Gt -> a * restb
    // Either side exhausted: nothing further can co-occur.
    _ -> End
|
|
|
|
// Occurrence of at least one of two distributions: the union of the supports
// of `a` and `b`. Supports present in only one input keep their value
// unchanged, and shared supports combine values with the caller-supplied `(+)`.
// Both inputs are walked in one merge-style pass, relying on their
// increasing-support order.
pub fun (+)(a: pmf<s, v>, b: pmf<s, v>, ?s/cmp: (a: s, b: s) -> order, ?v/(+): (a: v, b: v) -> e v): e pmf<s, v>
  match (a, b)
    (End, _) -> b
    (_, End) -> a
    (Event(sa, va, resta), Event(sb, vb, restb)) ->
      match cmp(sa, sb)
        Lt -> Event(sa, va, resta + b)
        Eq -> Event(sa, va + vb, resta + restb)
        Gt -> Event(sb, vb, a + restb)
|