From 241cf1a9d04eb4842e625f17c2f4798cf7244abd Mon Sep 17 00:00:00 2001 From: filifa Date: Fri, 2 May 2025 02:04:10 -0400 Subject: [PATCH] use multigraph implementation to handle self loops --- cmd/internal/markov/absorbing.go | 37 ++++++++++++++++++++++---------- 1 file changed, 26 insertions(+), 11 deletions(-) diff --git a/cmd/internal/markov/absorbing.go b/cmd/internal/markov/absorbing.go index bf4f339..29c2e49 100644 --- a/cmd/internal/markov/absorbing.go +++ b/cmd/internal/markov/absorbing.go @@ -23,22 +23,23 @@ import ( "gonum.org/v1/gonum/graph" "gonum.org/v1/gonum/graph/encoding" + "gonum.org/v1/gonum/graph/multi" "gonum.org/v1/gonum/graph/simple" "gonum.org/v1/gonum/mat" ) type AbsorbingMarkovChain struct { - *simple.WeightedDirectedGraph + *multi.WeightedDirectedGraph } func NewAbsorbingMarkovChain() *AbsorbingMarkovChain { - return &AbsorbingMarkovChain{WeightedDirectedGraph: simple.NewWeightedDirectedGraph(math.NaN(), 0)} + return &AbsorbingMarkovChain{WeightedDirectedGraph: multi.NewWeightedDirectedGraph()} } func (g *AbsorbingMarkovChain) IsValid() bool { for nodes := g.Nodes(); nodes.Next(); { u := nodes.Node().(*node) - if g.outWeightSum(u) > 1 { + if g.outWeightSum(u) != 1 { return false } } @@ -52,7 +53,7 @@ func (g *AbsorbingMarkovChain) outWeightSum(u graph.Node) float64 { v := nodes.Node() e := g.WeightedEdge(u.ID(), v.ID()) if e != nil { - sum += e.(*weightedEdge).W + sum += e.Weight() } } @@ -60,9 +61,14 @@ func (g *AbsorbingMarkovChain) outWeightSum(u graph.Node) float64 { } func (g *AbsorbingMarkovChain) AdjacencyMatrix() mat.Matrix { - adj := simple.NewDirectedMatrix(g.Nodes().Len(), 0, math.NaN(), 0) + adj := simple.NewDirectedMatrix(g.Nodes().Len(), 0, 0, 0) for edges := g.WeightedEdges(); edges.Next(); { - adj.SetWeightedEdge(edges.WeightedEdge()) + e := edges.WeightedEdge() + if e.From() == e.To() { + continue + } + + adj.SetWeightedEdge(e) } a := mat.DenseCopyOf(adj.Matrix()) @@ -71,15 +77,19 @@ func (g *AbsorbingMarkovChain) AdjacencyMatrix() mat.Matrix { for i := 0; nodes.Next(); i++ { id := nodes.Node().ID() u := g.Node(id).(*node) - a.Set(i, i, 1-g.outWeightSum(u)) + + e := g.WeightedEdge(u.ID(), u.ID()) + if e != nil { + a.Set(i, i, e.Weight()) + } } return a } func (g *AbsorbingMarkovChain) NewEdge(from, to graph.Node) graph.Edge { - e := g.WeightedDirectedGraph.NewWeightedEdge(from, to, math.NaN()).(simple.WeightedEdge) - return &weightedEdge{WeightedEdge: e} + e := g.WeightedDirectedGraph.NewWeightedLine(from, to, math.NaN()).(multi.WeightedLine) + return &weightedEdge{WeightedLine: e} } func (g *AbsorbingMarkovChain) NewNode() graph.Node { @@ -87,11 +97,16 @@ func (g *AbsorbingMarkovChain) NewNode() graph.Node { } func (g *AbsorbingMarkovChain) SetEdge(e graph.Edge) { - g.WeightedDirectedGraph.SetWeightedEdge(e.(*weightedEdge)) + g.WeightedDirectedGraph.SetWeightedLine(e.(*weightedEdge)) } type weightedEdge struct { - simple.WeightedEdge + multi.WeightedLine +} + +func (e *weightedEdge) ReversedEdge() graph.Edge { + revLine := multi.WeightedLine{F: e.T, T: e.F, W: e.W} + return &weightedEdge{WeightedLine: revLine} } func (e *weightedEdge) SetAttribute(attr encoding.Attribute) error {