Skip to content

Commit

Permalink
Monad instance for Vector and Matrix (#607)
Browse files Browse the repository at this point in the history
  • Loading branch information
cannorin authored Sep 22, 2024
1 parent f2e49ba commit 43050e6
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 53 deletions.
33 changes: 31 additions & 2 deletions src/FSharpPlus.TypeLevel/Data/Matrix.fs
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,17 @@ module Vector =

let inline apply (f: Vector<'a -> 'b, 'n>) (v: Vector<'a, 'n>) : Vector<'b, 'n> = map2 id f v

/// <description>
/// Converts the vector of vectors to a square matrix and returns its diagonal.
/// </description>
/// <seealso href="https://stackoverflow.com/questions/5802628/monad-instance-of-a-number-parameterised-vector" />
[<MethodImpl(MethodImplOptions.AggressiveInlining)>]
let join (vv: Vector<Vector<'a, 'n>, 'n>): Vector<'a, 'n> =
{ Items = Array.init (Array.length vv.Items) (fun i -> vv.Items.[i].Items.[i]) }

let inline bind (f: 'a -> Vector<'b, 'n>) (v: Vector<'a, 'n>) : Vector<'b, 'n> =
v |> map f |> join

let inline norm (v: Vector< ^a, ^n >) : ^a =
v |> toArray |> Array.sumBy (fun x -> x * x) |> sqrt
let inline maximumNorm (v: Vector< ^a, ^n >) : ^a =
Expand Down Expand Up @@ -327,6 +338,20 @@ module Matrix =
for j = 0 to Array2D.length2 m1.Items - 1 do
f i j m1.Items.[i, j] m2.Items.[i, j]

let inline apply (f: Matrix<'a -> 'b, 'm, 'n>) (m: Matrix<'a, 'm, 'n>) : Matrix<'b, 'm, 'n> = map2 id f m

/// <description>
/// Converts the matrix of matrices to a 3D cube matrix and returns its diagonal.
/// </description>
/// <seealso href="https://stackoverflow.com/questions/5802628/monad-instance-of-a-number-parameterised-vector" />
[<MethodImpl(MethodImplOptions.AggressiveInlining)>]
let join (m: Matrix<Matrix<'a, 'm, 'n>, 'm, 'n>) : Matrix<'a, 'm, 'n> =
{ Items =
Array2D.init (Array2D.length1 m.Items) (Array2D.length2 m.Items)
(fun i j -> m.Items.[i, j].Items.[i, j] ) }

let inline bind (f: 'a -> Matrix<'b, 'm, 'n>) (m: Matrix<'a, 'm, 'n>) : Matrix<'b, 'm, 'n> = m |> map f |> join

let inline rowLength (_: Matrix<'a, 'm, 'n>) : 'm = Singleton<'m>
let inline colLength (_: Matrix<'a, 'm, 'n>) : 'n = Singleton<'n>
let inline rowLength' (_: Matrix<'a, ^m, 'n>) : int = RuntimeValue (Singleton< ^m >)
Expand Down Expand Up @@ -571,8 +596,10 @@ type Matrix<'Item, 'Row, 'Column> with

static member inline Return (x: 'x) : Matrix<'x, 'm, 'n> = Matrix.replicate Singleton Singleton x
static member inline Pure (x: 'x) : Matrix<'x, 'm, 'n> = Matrix.replicate Singleton Singleton x
static member inline ( <*> ) (f: Matrix<'x -> 'y, 'm, 'n>, x: Matrix<'x, 'm, 'n>) = Matrix.map2 id f x
static member inline ( <.> ) (f: Matrix<'x -> 'y, 'm, 'n>, x: Matrix<'x, 'm, 'n>) = Matrix.map2 id f x
static member inline ( <*> ) (f: Matrix<'x -> 'y, 'm, 'n>, x: Matrix<'x, 'm, 'n>) = Matrix.apply f x
static member inline ( <.> ) (f: Matrix<'x -> 'y, 'm, 'n>, x: Matrix<'x, 'm, 'n>) = Matrix.apply f x
static member inline Join (x: Matrix<Matrix<'x, 'm, 'n>, 'm, 'n>) = Matrix.join x
static member inline ( >>= ) (x: Matrix<'x, 'm, 'n>, f: 'x -> Matrix<'y, 'm, 'n>) = Matrix.bind f x
static member inline get_Zero () : Matrix<'a, 'm, 'n> = Matrix.zero
static member inline ( + ) (m1, m2) = Matrix.map2 (+) m1 m2
static member inline ( - ) (m1, m2) = Matrix.map2 (-) m1 m2
Expand Down Expand Up @@ -607,6 +634,8 @@ type Vector<'Item, 'Length> with
static member inline Pure (x: 'x) : Vector<'x, 'n> = Vector.replicate Singleton x
static member inline ( <*> ) (f: Vector<'x -> 'y, 'n>, x: Vector<'x, 'n>) : Vector<'y, 'n> = Vector.apply f x
static member inline ( <.> ) (f: Vector<'x -> 'y, 'n>, x: Vector<'x, 'n>) : Vector<'y, 'n> = Vector.apply f x
static member inline Join (x: Vector<Vector<'x, 'n>, 'n>) : Vector<'x, 'n> = Vector.join x
static member inline ( >>= ) (x: Vector<'x, 'n>, f: 'x -> Vector<'y, 'n>) = Vector.bind f x

[<EditorBrowsable(EditorBrowsableState.Never)>]
static member inline Zip (x, y) = Vector.zip x y
Expand Down
1 change: 1 addition & 0 deletions tests/FSharpPlus.Tests/FSharpPlus.Tests.fsproj
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
<Compile Include="Lens.fs" />
<Compile Include="Extensions.fs" />
<Compile Include="BifoldableTests.fs" />
<Compile Include="Matrix.fs" />
<Compile Include="TypeLevel.fs" />
</ItemGroup>
<ItemGroup>
Expand Down
98 changes: 98 additions & 0 deletions tests/FSharpPlus.Tests/Matrix.fs
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
namespace FSharpPlus.Tests

open System
open NUnit.Framework
open Helpers

open FSharpPlus
open FSharpPlus.Data
open FSharpPlus.TypeLevel

module VectorTests =
[<Test>]
let constructorAndDeconstructorWorks() =
let v1 = vector (1,2,3,4,5)
let v2 = vector (1,2,3,4,5,6,7,8,9,0,1,2,3,4,5)
let (Vector(_,_,_,_,_)) = v1
let (Vector(_,_,_,_,_,_,_,_,_,_,_,_,_,_,_)) = v2
()

[<Test>]
let applicativeWorks() =
let v = vector ((fun i -> i + 1), (fun i -> i * 2))
let u = vector (2, 3)
let vu = v <*> u
NUnit.Framework.Assert.IsInstanceOf<Option<Vector<int,S<S<Z>>>>> (Some vu)
CollectionAssert.AreEqual ([|3; 6|], Vector.toArray vu)

[<Test>]
let satisfiesApplicativeLaws() =
let u = vector ((fun i -> i - 1), (fun i -> i * 2))
let v = vector ((fun i -> i + 1), (fun i -> i * 3))
let w = vector (1, 1)

areEqual (result id <*> v) v
areEqual (result (<<) <*> u <*> v <*> w) (u <*> (v <*> w))
areEqual (result 2) ((result (fun i -> i + 1) : Vector<int -> int, S<S<Z>>>) <*> result 1)
areEqual (u <*> result 1) (result ((|>) 1) <*> u)

[<Test>]
let satisfiesMonadLaws() =
let k = fun (a: int) -> vector (a - 1, a * 2)
let h = fun (a: int) -> vector (a + 1, a * 3)
let m = vector (1, 2)

areEqual (result 2 >>= k) (k 2)
areEqual (m >>= result) m
areEqual (m >>= (fun x -> k x >>= h)) ((m >>= k) >>= h)

module MatrixTests =
[<Test>]
let constructorAndDeconstructorWorks() =
let m1 =
matrix (
(1,0,0,0),
(0,1,0,0),
(0,0,1,0)
)
let m2 =
matrix (
(1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0),
(0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0),
(0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0),
(0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0),
(0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0),
(0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0),
(0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0),
(0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0)
)
let (Matrix(_x1,_x2,_x3)) = m1
let (Matrix(_y1: int*int*int*int*int*int*int*int*int*int*int*int*int*int*int*int,_y2,_y3,_y4,_y5,_y6,_y7,_y8)) = m2
()

[<Test>]
let satisfiesApplicativeLaws() =
let u = matrix (
((fun i -> i - 1), (fun i -> i * 2)),
((fun i -> i + 1), (fun i -> i * 3))
)
let v = matrix (
((fun i -> i - 2), (fun i -> i * 5)),
((fun i -> i + 2), (fun i -> i * 7))
)
let w = matrix ((1, 1), (1, 2))

areEqual (result id <*> v) v
areEqual (result (<<) <*> u <*> v <*> w) (u <*> (v <*> w))
areEqual ((result (fun i -> i + 1) : Matrix<int -> int, S<S<Z>>, S<S<Z>>>) <*> result 1) (result 2)
areEqual (u <*> result 1) (result ((|>) 1) <*> u)

[<Test>]
let satisfiesMonadLaws() =
let k = fun (a: int) -> matrix ((a - 1, a * 2), (a + 1, a * 3))
let h = fun (a: int) -> matrix ((a - 2, a * 5), (a + 2, a * 7))
let m = matrix ((1, 1), (1, 2))

areEqual (result 2 >>= k) (k 2)
areEqual (m >>= result) m
areEqual (m >>= (fun x -> k x >>= h)) ((m >>= k) >>= h)
52 changes: 1 addition & 51 deletions tests/FSharpPlus.Tests/TypeLevel.fs
Original file line number Diff line number Diff line change
Expand Up @@ -150,38 +150,8 @@ module NatTests =
Assert (g2 =^ S(S(S(S(S(S Z))))))


open FSharpPlus.Data

module MatrixTests =
[<Test>]
let matrixTests =
let v1 = vector (1,2,3,4,5)
let v2 = vector (1,2,3,4,5,6,7,8,9,0,1,2,3,4,5)
let (Vector(_,_,_,_,_)) = v1
let (Vector(_,_,_,_,_,_,_,_,_,_,_,_,_,_,_)) = v2

let m1 =
matrix (
(1,0,0,0),
(0,1,0,0),
(0,0,1,0)
)
let m2 =
matrix (
(1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0),
(0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0),
(0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0),
(0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0),
(0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0),
(0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0),
(0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0),
(0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0)
)
let (Matrix(_x1,_x2,_x3)) = m1
let (Matrix(_y1: int*int*int*int*int*int*int*int*int*int*int*int*int*int*int*int,_y2,_y3,_y4,_y5,_y6,_y7,_y8)) = m2
()

open Helpers
open FSharpPlus.Data

module TypeProviderTests =
type ``0`` = TypeNat<0>
Expand All @@ -206,23 +176,3 @@ module TypeProviderTests =
Assert (Matrix.colLength row1 =^ (Z |> S |> S |> S))
areEqual 5 (Matrix.get Z (S Z) row1)
areEqual [3; 6; 9] (Vector.toList col2)

module TestFunctors1 =
[<Test>]
let applicativeOperatorWorks() =
let v = vector ((fun i -> i + 1), (fun i -> i * 2))
let u = vector (2, 3)
let vu = v <*> u
NUnit.Framework.Assert.IsInstanceOf<Option<Vector<int,S<S<Z>>>>> (Some vu)
CollectionAssert.AreEqual ([|3; 6|], Vector.toArray vu)

module TestFunctors2 =
open FSharpPlus

[<Test>]
let applicativeWorksWithoutSubsumption() =
let v = vector ((fun i -> i + 1), (fun i -> i * 2))
let u = vector (2, 3)
let vu = v <*> u
NUnit.Framework.Assert.IsInstanceOf<Option<Vector<int,S<S<Z>>>>> (Some vu)
CollectionAssert.AreEqual ([|3; 6|], Vector.toArray vu)

0 comments on commit 43050e6

Please sign in to comment.