Dijkstra's algorithm in Python: algorithms for beginners

Maria Boldyreva ・5 min read

Photo by Ishan @seefromthesky on Unsplash

Dijkstra's algorithm finds the shortest path between two nodes of a graph. It's a must-know for any programmer. Its Wikipedia page has nice gifs and some history.

In this post I'll use the time-tested implementation from Rosetta Code, changed just a bit so that it can process both weighted and unweighted graph data. We'll also be able to edit the graph on the fly. I'll explain the code block by block.

The algorithm

The algorithm is pretty simple. Dijkstra created it in 20 minutes; now you can learn to code it in the same time.

  1. Mark all nodes unvisited and store them.
  2. Set the distance to zero for our initial node and to infinity for other nodes.
  3. Select the unvisited node with the smallest distance; it's the current node now.
  4. Find unvisited neighbors for the current node and calculate their distances through the current node. Compare the newly calculated distance to the assigned and save the smaller one. For example, if the node A has a distance of 6, and the A-B edge has length 2, then the distance to B through A will be 6 + 2 = 8. If B was previously marked with a distance greater than 8 then change it to 8.
  5. Mark the current node as visited and remove it from the unvisited set.
  6. Stop, if the destination node has been visited (when planning a route between two specific nodes) or if the smallest distance among the unvisited nodes is infinity. If not, repeat steps 3-6.
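
Step 4 is the heart of the algorithm, so here is a minimal sketch of it on the numbers from the example above. The names `relax` and `distances` are illustrative assumptions, not part of the implementation that follows.

```python
inf = float('inf')

# Node A already has distance 6; B has not been reached yet.
distances = {'A': 6, 'B': inf}


def relax(distances, current, neighbour, edge_length):
    """Keep the smaller of the stored distance and the route via `current`.

    Returns True if the stored distance was improved.
    """
    candidate = distances[current] + edge_length
    if candidate < distances[neighbour]:
        distances[neighbour] = candidate
        return True
    return False


relax(distances, 'A', 'B', 2)
print(distances['B'])  # 8
```

Calling `relax` again with a longer A-B edge would leave the stored value untouched, which is the "save the smaller one" rule from step 4.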

Python implementation

First, imports and data formats. The original implementation suggests using namedtuple for storing edge data. We'll do exactly that, but we'll add a default value to the cost argument. There are many ways to do that; find what suits you best.

from collections import deque, namedtuple


# we'll use infinity as a default distance to nodes.
inf = float('inf')
Edge = namedtuple('Edge', 'start, end, cost')


def make_edge(start, end, cost=1):
    return Edge(start, end, cost)
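
As one of the other ways to get a default cost, here is a sketch that relies on the `defaults` argument of `namedtuple`, which is available from Python 3.7 on. The name `EdgeWithDefault` is an assumption made for illustration, so it doesn't clash with the `Edge` used in the rest of the post.

```python
from collections import namedtuple

# `defaults` applies to the rightmost fields, so with one value
# only `cost` becomes optional.
EdgeWithDefault = namedtuple(
    'EdgeWithDefault', 'start, end, cost', defaults=[1])

unweighted = EdgeWithDefault('a', 'b')
weighted = EdgeWithDefault('a', 'b', 7)

print(unweighted.cost)  # 1
print(weighted.cost)    # 7
```

With this variant the `make_edge` helper would not be needed, at the price of requiring a newer Python version.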

Let's initialize our data:

class Graph:
    def __init__(self, edges):
        # let's check that the data is right
        wrong_edges = [i for i in edges if len(i) not in [2, 3]]
        if wrong_edges:
            raise ValueError('Wrong edges data: {}'.format(wrong_edges))

        self.edges = [make_edge(*edge) for edge in edges]

Let's find the vertices. In the original implementation the vertices are defined in __init__, but we need them to update when edges change, so we'll make them a property that is recomputed each time we access it. Probably not the best solution for big graphs, but for small ones it'll do.

    @property
    def vertices(self):
        return set(
            # this piece of magic turns ([1,2], [3,4]) into [1, 2, 3, 4]
            # the set above makes its elements unique.
            sum(
                ([edge.start, edge.end] for edge in self.edges), []
            )
        )
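
For bigger graphs, the `sum(..., [])` trick copies the growing list on every addition. A possible alternative, shown here as a standalone sketch, flattens the pairs with `itertools.chain.from_iterable` instead. The function name `collect_vertices` and the sample `edges` list are assumptions for illustration only.

```python
from collections import namedtuple
from itertools import chain

Edge = namedtuple('Edge', 'start, end, cost')

edges = [Edge('a', 'b', 7), Edge('b', 'c', 10)]


def collect_vertices(edges):
    # chain.from_iterable walks the (start, end) pairs lazily,
    # so no intermediate lists are built before the set is filled.
    return set(chain.from_iterable((e.start, e.end) for e in edges))


print(sorted(collect_vertices(edges)))  # ['a', 'b', 'c']
```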

Now, let's add functionality for adding and removing edges.

    def get_node_pairs(self, n1, n2, both_ends=True):
        if both_ends:
            node_pairs = [[n1, n2], [n2, n1]]
        else:
            node_pairs = [[n1, n2]]
        return node_pairs

    def remove_edge(self, n1, n2, both_ends=True):
        node_pairs = self.get_node_pairs(n1, n2, both_ends)
        edges = self.edges[:]
        for edge in edges:
            if [edge.start, edge.end] in node_pairs:
                self.edges.remove(edge)

    def add_edge(self, n1, n2, cost=1, both_ends=True):
        node_pairs = self.get_node_pairs(n1, n2, both_ends)
        for edge in self.edges:
            if [edge.start, edge.end] in node_pairs:
                raise ValueError('Edge {} {} already exists'.format(n1, n2))

        self.edges.append(Edge(start=n1, end=n2, cost=cost))
        if both_ends:
            self.edges.append(Edge(start=n2, end=n1, cost=cost))

Let's find neighbors for every node:

    @property
    def neighbours(self):
        neighbours = {vertex: set() for vertex in self.vertices}
        for edge in self.edges:
            neighbours[edge.start].add((edge.end, edge.cost))

        return neighbours

It's time for the algorithm! I renamed the variables so it would be easier to understand.

    def dijkstra(self, source, dest):
        assert source in self.vertices, 'Such source node doesn\'t exist'

        # 1. Mark all nodes unvisited and store them.
        # 2. Set the distance to zero for our initial node 
        # and to infinity for other nodes.
        distances = {vertex: inf for vertex in self.vertices}
        previous_vertices = {
            vertex: None for vertex in self.vertices
        }
        distances[source] = 0
        vertices = self.vertices.copy()

        while vertices:
            # 3. Select the unvisited node with the smallest distance, 
            # it's current node now.
            current_vertex = min(
                vertices, key=lambda vertex: distances[vertex])

            # 6. Stop, if the smallest distance 
            # among the unvisited nodes is infinity.
            if distances[current_vertex] == inf:
                break

            # 4. Find unvisited neighbors for the current node 
            # and calculate their distances through the current node.
            for neighbour, cost in self.neighbours[current_vertex]:
                alternative_route = distances[current_vertex] + cost

                # Compare the newly calculated distance to the assigned 
                # and save the smaller one.
                if alternative_route < distances[neighbour]:
                    distances[neighbour] = alternative_route
                    previous_vertices[neighbour] = current_vertex

            # 5. Mark the current node as visited 
            # and remove it from the unvisited set.
            vertices.remove(current_vertex)


        path, current_vertex = deque(), dest
        while previous_vertices[current_vertex] is not None:
            path.appendleft(current_vertex)
            current_vertex = previous_vertices[current_vertex]
        if path:
            path.appendleft(current_vertex)
        return path

Let's use it.

graph = Graph([
    ("a", "b", 7),  ("a", "c", 9),  ("a", "f", 14), ("b", "c", 10),
    ("b", "d", 15), ("c", "d", 11), ("c", "f", 2),  ("d", "e", 6),
    ("e", "f", 9)])

print(graph.dijkstra("a", "e"))
# deque(['a', 'c', 'd', 'e'])
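
If you need the total length together with the path, or want the search to stop as soon as the destination is reached, one common variant keeps the candidates in a priority queue. The sketch below is not the Rosetta Code implementation used in this post. It is a standalone function, and the name `shortest_path_with_length` is hypothetical. It assumes directed edges given as (start, end, cost) tuples, non-negative costs, and vertices that can be compared with each other (strings here), because the heap compares vertices when two distances tie.

```python
import heapq
from collections import defaultdict

inf = float('inf')


def shortest_path_with_length(edges, source, dest):
    """Return (total_cost, path), or (inf, []) if dest is unreachable."""
    adjacency = defaultdict(list)
    for start, end, cost in edges:
        adjacency[start].append((end, cost))

    distances = {source: 0}
    previous = {}
    queue = [(0, source)]  # (distance so far, vertex)
    visited = set()

    while queue:
        distance, vertex = heapq.heappop(queue)
        if vertex in visited:
            continue  # stale entry: a shorter route was handled earlier
        visited.add(vertex)
        if vertex == dest:
            break  # the distance to dest is final at this point
        for neighbour, cost in adjacency[vertex]:
            candidate = distance + cost
            if candidate < distances.get(neighbour, inf):
                distances[neighbour] = candidate
                previous[neighbour] = vertex
                heapq.heappush(queue, (candidate, neighbour))

    if dest not in distances:
        return inf, []

    # Walk back from dest to source, then reverse.
    path = [dest]
    while path[-1] != source:
        path.append(previous[path[-1]])
    path.reverse()
    return distances[dest], path


edges = [
    ("a", "b", 7),  ("a", "c", 9),  ("a", "f", 14), ("b", "c", 10),
    ("b", "d", 15), ("c", "d", 11), ("c", "f", 2),  ("d", "e", 6),
    ("e", "f", 9)]

print(shortest_path_with_length(edges, "a", "e"))
# (26, ['a', 'c', 'd', 'e'])
```

The returned path is a plain list rather than a deque, and the early `break` implements the "destination node has been visited" half of step 6.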

The whole code from above:

from collections import deque, namedtuple


# we'll use infinity as a default distance to nodes.
inf = float('inf')
Edge = namedtuple('Edge', 'start, end, cost')


def make_edge(start, end, cost=1):
    return Edge(start, end, cost)


class Graph:
    def __init__(self, edges):
        # let's check that the data is right
        wrong_edges = [i for i in edges if len(i) not in [2, 3]]
        if wrong_edges:
            raise ValueError('Wrong edges data: {}'.format(wrong_edges))

        self.edges = [make_edge(*edge) for edge in edges]

    @property
    def vertices(self):
        return set(
            sum(
                ([edge.start, edge.end] for edge in self.edges), []
            )
        )

    def get_node_pairs(self, n1, n2, both_ends=True):
        if both_ends:
            node_pairs = [[n1, n2], [n2, n1]]
        else:
            node_pairs = [[n1, n2]]
        return node_pairs

    def remove_edge(self, n1, n2, both_ends=True):
        node_pairs = self.get_node_pairs(n1, n2, both_ends)
        edges = self.edges[:]
        for edge in edges:
            if [edge.start, edge.end] in node_pairs:
                self.edges.remove(edge)

    def add_edge(self, n1, n2, cost=1, both_ends=True):
        node_pairs = self.get_node_pairs(n1, n2, both_ends)
        for edge in self.edges:
            if [edge.start, edge.end] in node_pairs:
                raise ValueError('Edge {} {} already exists'.format(n1, n2))

        self.edges.append(Edge(start=n1, end=n2, cost=cost))
        if both_ends:
            self.edges.append(Edge(start=n2, end=n1, cost=cost))

    @property
    def neighbours(self):
        neighbours = {vertex: set() for vertex in self.vertices}
        for edge in self.edges:
            neighbours[edge.start].add((edge.end, edge.cost))

        return neighbours

    def dijkstra(self, source, dest):
        assert source in self.vertices, 'Such source node doesn\'t exist'
        distances = {vertex: inf for vertex in self.vertices}
        previous_vertices = {
            vertex: None for vertex in self.vertices
        }
        distances[source] = 0
        vertices = self.vertices.copy()

        while vertices:
            current_vertex = min(
                vertices, key=lambda vertex: distances[vertex])
            vertices.remove(current_vertex)
            if distances[current_vertex] == inf:
                break
            for neighbour, cost in self.neighbours[current_vertex]:
                alternative_route = distances[current_vertex] + cost
                if alternative_route < distances[neighbour]:
                    distances[neighbour] = alternative_route
                    previous_vertices[neighbour] = current_vertex

        path, current_vertex = deque(), dest
        while previous_vertices[current_vertex] is not None:
            path.appendleft(current_vertex)
            current_vertex = previous_vertices[current_vertex]
        if path:
            path.appendleft(current_vertex)
        return path


graph = Graph([
    ("a", "b", 7),  ("a", "c", 9),  ("a", "f", 14), ("b", "c", 10),
    ("b", "d", 15), ("c", "d", 11), ("c", "f", 2),  ("d", "e", 6),
    ("e", "f", 9)])

print(graph.dijkstra("a", "e"))

P.S. For those of us who, like me, read more books about the Witcher than about algorithms, it's Edsger Dijkstra, not Sigismund.

Discussion

Nicely done!

If you are only trying to get from A to B in a graph... then the A* algorithm usually performs slightly better: en.wikipedia.org/wiki/A*_search_al... That's what many SatNav packages use :)

 

Yep! I will write about it soon. Thanks for reading :)

 

This next part could be written a little bit shorter:

    path, current_vertex = deque(), dest
    while previous_vertices[current_vertex] is not None:
        path.appendleft(current_vertex)
        current_vertex = previous_vertices[current_vertex]
    if path:
        path.appendleft(current_vertex)

----->>>

    path, current_vertex = deque(), dest
    while current_vertex:
        path.appendleft(current_vertex)
        current_vertex = previous_vertices[current_vertex]

 

Using Python object-oriented knowledge, I made the following modification to the dijkstra method to make it return the distance instead of the path as a deque object.

    # return the distance between the nodes
    distance_between_nodes = 0
    for index in range(1, len(path)):
        for thing in self.edges:
            if thing.start == path[index - 1] and thing.end == path[index]:
                distance_between_nodes += thing.cost
    return distance_between_nodes

    # return path

 

What changes should I make if I don't want to use the deque() data structure?
Also, how can I read the graph contents from a file with the structure below?
a / b / 3
a / c / 5
...
Source node: a
Destination node: j

 

If we want to know the shortest path and the total length at the same time,
how should we change the code?

 

The algorithm says:

6. Stop, if the destination node has been visited.

In the code:

    # 6. Stop, if the smallest distance
    # among the unvisited nodes is infinity.
    if distances[current_vertex] == inf:
        break

The code visits all nodes even after the destination has been visited.

 

Hello Maria,

First of all, thank you for taking the time to share your knowledge with all of us! I am sure that your code will be of much use to many people, me amongst them!

I actually have two questions.

First: do you know, or have you heard of, how to change the weights of your graph after each movement?
In my case, I would like to prevent my graph from moving through certain edges by setting them to 'Inf' in each iteration (later, I would remove these 'Inf' values and set them to other ones).

Second: Do you know how to include restrictions to Dijkstra, so that the path between certain vertices goes through a fixed number of edges?

Many thanks in advance, and best regards!

 

Hello back,

I was finally able to find a solution to change the weights dynamically during the search process. However, I am still not sure how to impose the condition of having a path of length >= N, N being the number of traversed edges. The only idea I have come up with would consist of turning to infinity the last edge towards my destination vertex if the overall distance lies below N. However, this would make this edge no longer available to the other paths that arrive at the destination vertex. Any ideas from your side, folks? Many thanks in advance, and best regards!

 

This algorithm works correctly only if the graph is directed; if the graph is undirected, it will not.

To make the algorithm work on an undirected graph, you will have to edit the neighbours property as follows:

@property
def neighbours(self):
    neighbours = {vertex: set() for vertex in self.vertices}
    for edge in self.edges:
        neighbours[edge.start].add((edge.end, edge.cost))
        neighbours[edge.end].add((edge.start, edge.cost))
    return neighbours
 

for beginners? sure it's packed with 'advanced' py features. @submit, namedtuple, list comprehensions, you name it!

 

Can you please tell us what the asymptotic complexity of this algorithm is, and why?

 

Thank you Maria, this is exactly what I was looking for... good code with a good explanation to better understand this algorithm.