Dijkstra's algorithm using collections.namedtuple

Submitted by: @import:stackexchange-codereview

Problem

I decided to try out named tuples by implementing Dijkstra's algorithm to find the cheapest routes in a file like this (where each line means that node_a is connected to node_b with a cost of n):

1 6 14
1 2 7
1 3 9
2 3 1
2 4 15
3 6 2
3 4 11
4 5 6
5 6 9


However, something that caught my attention is that some of the lines got really long:

```
import sys
from collections import namedtuple

INFINITY = 999999

class UndirectedGraph:
    def __init__(self, node_list):
        self.Node = namedtuple('Node', ['coming_from', 'cost'])
        self.node_dict = self.get_nodes(node_list)
        self.create_connections(node_list)
        self.size = len(self.node_dict)

    def get_nodes(self, node_list):
        '''
        Gets a list of tuples (node1, node2, weight_of_connection) and
        distribute theses nodes through a dictionaire
        '''
        node_dict = {}
        for line in node_list:
            a_node, another_node = line[0], line[1]
            node_dict[a_node] = []
            node_dict[another_node] = []
        return node_dict

    def create_connections(self, raw_node_list):
        '''
        Creates the connection between the nodes in the nodes dict
        '''
        try:
            for n in raw_node_list:
                current_node, neighbor, cost = n
                self.node_dict[current_node].append(self.get_new_node(neighbor, cost))
                self.node_dict[neighbor].append(self.get_new_node(current_node, cost))
        except:
            print("General error: {}".format(sys.exc_info()[0]))
            raise

    def dijkstra(self, source):
        '''
        Applies the dijkstra algorithm for finding the shortest path to every
        other node coming from source.
        Returns a list containing the label of the early node and the cost of the
        total path to the given node
        '''
        if source not in self.node_dict:
            raise ValueError('Node informed does not exist')
```

Solution

One of the first things I notice is that you define Node as an instance attribute, so the namedtuple class is rebuilt every time a graph is created. It would be better to define it once at the top level:

Node = namedtuple('Node', ['coming_from', 'cost'])


Also, in Dijkstra's algorithm, the key point is not the nodes, but the edges. An edge is a tuple (source, destination, weight). But let's first see what we can clean up before tackling that.

Your UndirectedGraph takes a list of things. It claims to be a list of nodes, but it's actually a list of (node1, node2, weight_of_connection) lists: edges. So maybe it would make sense to define a namedtuple Edge:

Edge = namedtuple('Edge', ['source', 'target', 'cost'])


And in main:

edges = []
with open('dij.txt', 'r') as f:
    for line in f:
        source, target, cost = map(int, line.split())
        edges.append(Edge(source, target, cost))
my_graph = UndirectedGraph(node_list=edges)
...
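To try the parsing step without a dij.txt on disk, here is a minimal sketch that reads the sample data from the question out of a string. The read_edges helper and the SAMPLE constant are names made up for this illustration; they are not part of the original code.

```python
from collections import namedtuple
from io import StringIO

Edge = namedtuple('Edge', ['source', 'target', 'cost'])

# Sample data from the question, inlined so the sketch runs without dij.txt.
SAMPLE = """\
1 6 14
1 2 7
1 3 9
2 3 1
2 4 15
3 6 2
3 4 11
4 5 6
5 6 9
"""


def read_edges(lines):
    """Parse 'source target cost' lines into Edge tuples (hypothetical helper)."""
    edges = []
    for line in lines:
        if not line.strip():
            continue  # tolerate blank lines, e.g. at the end of the file
        source, target, cost = map(int, line.split())
        edges.append(Edge(source, target, cost))
    return edges


# Any iterable of lines works here, including an open file object.
edges = read_edges(StringIO(SAMPLE))
print(edges[0])    # Edge(source=1, target=6, cost=14)
print(len(edges))  # 9
```

Since the helper accepts any iterable of lines, the same function should work with the file handle from the with block above.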


Of course, we should rename node_list to edges, because it's a list of edges. (Don't forget to change node_list to edges in the call as well, or drop the node_list= keyword completely.)

def __init__(self, edges):
    self.node_dict = self.get_nodes(edges)
    self.create_connections(edges)
    self.size = len(self.node_dict)


Now, let's take a look at self.get_nodes. It's supposed to construct a node_dict: a dictionary of nodes to lists (initially empty).

def get_nodes(self, node_list):
    '''
    Gets a list of tuples (node1, node2, weight_of_connection) and
    distribute theses nodes through a dictionaire 
    '''
    node_dict = {}
    for line in node_list:
        a_node, another_node = line[0], line[1]
        node_dict[a_node] = [] 
        node_dict[another_node] = []    
    return node_dict


Because the input is now a list of Edge tuples, we can write it as follows:

def get_nodes(self, edges):
    '''
    Takes a list of edges and maps every node they mention
    to an empty list in a dictionary
    '''
    node_dict = {}
    for edge in edges:
        node_dict[edge.source] = [] 
        node_dict[edge.target] = []    
    return node_dict


But even simpler, we could replace it with a defaultdict:

from collections import defaultdict

...
class UndirectedGraph(object):
    ...
    def get_nodes(self, edges):
        '''
        Returns a dictionary that maps each node to a list, creating
        the empty list the first time a node is looked up
        '''
        return defaultdict(list)
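As a quick illustration of why the explicit initialization loop becomes unnecessary, here is a small sketch of how defaultdict(list) behaves. The (neighbor, cost) tuples are placeholder values chosen for this example only.

```python
from collections import defaultdict

node_dict = defaultdict(list)

# Neither key exists yet; the first lookup creates an empty list to append to.
node_dict[1].append((2, 7))
node_dict[2].append((1, 7))
node_dict[1].append((3, 9))

print(dict(node_dict))   # {1: [(2, 7), (3, 9)], 2: [(1, 7)]}

# A membership test does not create a key, so a check like
# "if source not in self.node_dict" keeps working as before.
print(99 in node_dict)   # False
print(len(node_dict))    # 2
```

One thing to keep in mind: indexing a missing key (node_dict[99]) does insert it, so existence checks should use in rather than indexing.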


At which point we can just write

def __init__(self, edges):
    self.node_dict = defaultdict(list)
    self.create_connections(edges)
    self.size = len(self.node_dict)


and remove the get_nodes method. Now let's take a look at the create_connections method.

def create_connections(self, raw_node_list):
    '''
    Creates the connection between the nodes in the nodes dict
    '''
    try:
        for n in raw_node_list:
            current_node, neighbor, cost = n
            self.node_dict[current_node].append(self.get_new_node(neighbor, cost))
            self.node_dict[neighbor].append(self.get_new_node(current_node, cost))
    except:
        print("General error: {}".format(sys.exc_info()[0]))
        raise


Why do you even have the try/except? It only prints the exception type and re-raises, which the traceback already tells you, so it adds nothing of value. Also, the parameter is now a list of edges, not a raw_node_list. Another thing I'd suggest is using the attributes of the Edge class we just defined.

def create_connections(self, edges):
    '''
    Creates the connection between the nodes in the nodes dict
    '''
    for edge in edges:
        self.node_dict[edge.source].append(self.get_new_node(edge.target, edge.cost))
        self.node_dict[edge.target].append(self.get_new_node(edge.source, edge.cost))
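Putting the pieces so far together, the class could look roughly like the sketch below. The body of get_new_node is not shown in the question, so the version here is a guess that simply wraps its arguments in a Node; treat it as an assumption.

```python
from collections import defaultdict, namedtuple

Node = namedtuple('Node', ['coming_from', 'cost'])
Edge = namedtuple('Edge', ['source', 'target', 'cost'])


class UndirectedGraph:
    def __init__(self, edges):
        self.node_dict = defaultdict(list)
        self.create_connections(edges)
        # Must come after create_connections, which is what fills node_dict.
        self.size = len(self.node_dict)

    def get_new_node(self, coming_from, cost):
        # Assumption: the original helper is not shown; this guess just
        # wraps the values in a Node.
        return Node(coming_from, cost)

    def create_connections(self, edges):
        '''
        Registers every edge in both directions, since the graph is undirected
        '''
        for edge in edges:
            self.node_dict[edge.source].append(self.get_new_node(edge.target, edge.cost))
            self.node_dict[edge.target].append(self.get_new_node(edge.source, edge.cost))


graph = UndirectedGraph([Edge(1, 2, 7), Edge(2, 3, 1)])
print(graph.size)          # 3
print(graph.node_dict[2])  # [Node(coming_from=1, cost=7), Node(coming_from=3, cost=1)]
```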


Now I'd like to take a look at the algorithm itself, and I see several things I'd like to suggest, but don't have time for a full refactoring:

  • costs_array: consider making it a costs_dict, with a node as the key and the cheapest known path cost (an integer) as the value. For instance, initialize it as costs_dict = {source: 0}.

  • You're continuously constructing Node objects, while all you need to continue is the {node_name: path_cost} mapping.
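To make the two suggestions above concrete, here is one possible sketch of the algorithm that keeps only a costs_dict of {node_name: path_cost}. It is not a refactoring of the original method: it is written as a free function, it uses heapq as the priority queue (a choice made here, not something taken from the question), and the neighbors name is invented for this example.

```python
import heapq
from collections import defaultdict, namedtuple

Edge = namedtuple('Edge', ['source', 'target', 'cost'])


def dijkstra(edges, source):
    """Return {node: cheapest path cost from source} for an undirected graph."""
    # Adjacency mapping: node -> list of (neighbor, cost) pairs.
    neighbors = defaultdict(list)
    for edge in edges:
        neighbors[edge.source].append((edge.target, edge.cost))
        neighbors[edge.target].append((edge.source, edge.cost))
    if source not in neighbors:
        raise ValueError('Node informed does not exist')

    costs_dict = {source: 0}
    queue = [(0, source)]  # min-heap of (path_cost, node)
    while queue:
        cost, node = heapq.heappop(queue)
        if cost > costs_dict[node]:
            continue  # stale entry: a cheaper path to node was already found
        for neighbor, weight in neighbors[node]:
            new_cost = cost + weight
            if new_cost < costs_dict.get(neighbor, float('inf')):
                costs_dict[neighbor] = new_cost
                heapq.heappush(queue, (new_cost, neighbor))
    return costs_dict


# The sample data from the question.
edges = [Edge(1, 6, 14), Edge(1, 2, 7), Edge(1, 3, 9), Edge(2, 3, 1), Edge(2, 4, 15),
         Edge(3, 6, 2), Edge(3, 4, 11), Edge(4, 5, 6), Edge(5, 6, 9)]
print(sorted(dijkstra(edges, 1).items()))
# [(1, 0), (2, 7), (3, 8), (4, 19), (5, 19), (6, 10)]
```

Nodes missing from costs_dict are treated as infinitely far away, so no INFINITY constant is needed. This version only returns costs; recovering the actual routes would need an extra predecessor mapping.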



I'm also a bit uncertain if your implementation of the algorithm is entirely correct, but I'd need to look that up a bit more.

Context

StackExchange Code Review Q#118606, answer score: 5
