diff --git a/cmd/internal/graph/common.go b/cmd/internal/graph/common.go index 9d242a1..2783b83 100644 --- a/cmd/internal/graph/common.go +++ b/cmd/internal/graph/common.go @@ -23,6 +23,7 @@ import ( type WeightedGraph interface { graph.Weighted + graph.WeightedMultigraph graph.WeightedMultigraphBuilder WeightedEdges() graph.WeightedEdges diff --git a/cmd/internal/graph/dot/graph.go b/cmd/internal/graph/dot/graph.go index 8291e76..0a55aaa 100644 --- a/cmd/internal/graph/dot/graph.go +++ b/cmd/internal/graph/dot/graph.go @@ -41,7 +41,7 @@ func NewDOTUndirectedGraph(weightAttr string) DOTWeightedGraph { } // NewLine returns a DOT-aware weighted line. -func (g *DOTWeightedGraph) NewLine(from, to graph.Node) graph.Line { +func (g DOTWeightedGraph) NewLine(from, to graph.Node) graph.Line { var defaultWeight float64 if g.WeightAttribute == "" { defaultWeight = 1 @@ -54,11 +54,11 @@ func (g *DOTWeightedGraph) NewLine(from, to graph.Node) graph.Line { } // NewNode returns a DOT-aware Node. -func (g *DOTWeightedGraph) NewNode() graph.Node { +func (g DOTWeightedGraph) NewNode() graph.Node { return &Node{Node: g.WeightedGraph.NewNode()} } // SetLine adds a DOT-aware weighted line to the graph. -func (g *DOTWeightedGraph) SetLine(e graph.Line) { +func (g DOTWeightedGraph) SetLine(e graph.Line) { g.WeightedGraph.SetWeightedLine(e.(*weightedLine)) } diff --git a/cmd/root.go b/cmd/root.go index 770dd51..b641dbe 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -25,6 +25,7 @@ import ( "github.com/spf13/cobra" "gonum.org/v1/gonum/graph/encoding/dot" + dotfmt "gonum.org/v1/gonum/graph/formats/dot" "gonum.org/v1/gonum/mat" ) @@ -60,7 +61,19 @@ func parse(cmd *cobra.Command, args []string) { panic(err) } - graph := idot.NewDOTDirectedGraph(weightAttr) + ast, err := dotfmt.ParseBytes(data) + if err != nil { + panic(err) + } + + first := ast.Graphs[0] + var graph idot.DOTWeightedGraph + if first.Directed { + graph = idot.NewDOTDirectedGraph(weightAttr) + } else { + graph = idot.NewDOTUndirectedGraph(weightAttr) + } + err = dot.UnmarshalMulti(data, graph) if err != nil { panic(err) @@ -74,7 +87,7 @@ func parse(cmd *cobra.Command, args []string) { outputMatrix(matrix) } -func orderedAdjMatrix(g *idot.DOTDirectedGraph) (*mat.Dense, error) { +func orderedAdjMatrix(g idot.DOTWeightedGraph) (*mat.Dense, error) { matrix := g.AdjacencyMatrix() if len(nodeOrder) == 0 { return matrix, nil