DEV Community

Yuan Gao


Advent of Code 2020: Day 06 using Python sets

Another short one, this will be quick.

Things mentioned in this post: set theory, list comprehension, intersection, map function, destructuring/splat

The Challenge Part 1

The challenge asks you to find out how many unique members there are in each set. One of the examples given was:

``````ab
ac
``````

Regardless of what the actual flavour text says about this, the task is ultimately to find the unique set of letters and count them. In this case, that is just `a`, `b`, and `c`, so the count is 3.

Python sets

Python has built-in sets, which are very versatile. We can simply grab all the data, and split them into entries as we did before:

``````data = open("input.txt").read().split("\n\n")
``````

An example entry (one member of `data`) looks like this:

``````'donpevkjhymzl\nezyopckdlnvmj'
``````

This has the newline in it, so we need to strip that out for this first part:

``````entry.replace("\n","")
``````

Output

``````'donpevkjhymzlezyopckdlnvmj'
``````

Then, we simply stick this in a `set()` which automatically treats each character as a separate member, and de-duplicates it for us.

``````set(entry.replace("\n", ""))
``````

Output

``````{'c', 'd', 'e', 'h', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'v', 'y', 'z'}
``````

The question asks for the sum of these counts across all entries, so we first wrap the set in `len()` to find how many members it has:

``````len(set(entry.replace("\n", "")))
``````

Output

``````14
``````

Repeat this over all entries in the dataset:

``````sum([len(set(entry.replace("\n",""))) for entry in open("input.txt").read().split("\n\n")])
``````

That's the whole solution for Part 1
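As a quick sanity check, the Part 1 steps can be wrapped in a function and run on the two entries shown above. This is only a sketch: the function name `count_unique_answers` and the `sample` string are made up for illustration, and the `strip()` call is an added assumption to tolerate a trailing newline at the end of the input.

```python
def count_unique_answers(raw: str) -> int:
    """Sum the number of distinct letters in each blank-line-separated entry."""
    # strip() is an assumption here: it drops a trailing newline at the end of the input
    entries = raw.strip().split("\n\n")
    # Remove newlines so only letters remain, then let set() de-duplicate them
    return sum(len(set(entry.replace("\n", ""))) for entry in entries)


# Hypothetical sample: the "ab/ac" example plus the entry shown earlier
sample = "ab\nac\n\ndonpevkjhymzl\nezyopckdlnvmj\n"
part1_total = count_unique_answers(sample)
print(part1_total)  # 3 + 14 = 17
```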

The Challenge Part 2

Part 2 switches things up a bit: instead of finding the unique members of the whole entry, the ask is to find the members that are common to every item in each entry. So for the example:

``````ab
ac
``````

Only `a` is common.

So, we have to break each entry down into individual items. Python's sets have an `intersection()` method which gives the common members. However, as we have to find the members common to all the items in an entry, and there may be more than two of them, we need to intersect several sets at once.

So, taking an entry, we can split it into items using the regular `split()`, which with no arguments splits on any whitespace, including newlines:

``````items = entry.split()
``````

Output

``````['donpevkjhymzl', 'ezyopckdlnvmj']
``````

Python has a `set.intersection()` method that takes an intersection of every set argument passed to it. For example, if we make sets out of the two items:

``````set.intersection(set(items[0]), set(items[1]))
``````

Output

``````{'d', 'e', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'v', 'y', 'z'}
``````
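For an entry with more than two items, the same call accepts further sets as extra arguments, and the `&` operator can be chained to the same effect. A small sketch, using a made-up three-item entry rather than real puzzle input:

```python
# Made-up entry with three items, for illustration only
items = ["abcx", "abcy", "abcz"]
sets = [set(item) for item in items]

# set.intersection() accepts any number of sets...
one_call = set.intersection(sets[0], sets[1], sets[2])
# ...and is equivalent to chaining the & operator
chained = sets[0] & sets[1] & sets[2]

print(sorted(one_call))  # ['a', 'b', 'c']
print(one_call == chained)  # True
```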

However, we have to be able to generalize this for an arbitrary number of items. We can do this with destructuring (or "splatting", or "unpacking"), allowing us to pass an arbitrary (and variable) number of arguments to `set.intersection()`

``````set.intersection(*[set(item) for item in items])
``````

(same output)

Here, we use a list comprehension to apply `set()` to each of the items. We can also use the `map()` function, which does pretty much the same thing:

``````set.intersection(*map(set, items))
``````

(same output)

The `*` here is the "splat" operator, which means "unpack this iterable, and use each of its members as a separate argument to the function". The terminology varies: some call it "splat", some "unpack", some "destructure", and some "expanding". The Python documentation itself calls it unpacking.
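To see what the `*` changes, it can help to count how many arguments a function actually receives with and without it. A minimal sketch, where `count_args` is a hypothetical helper written just for this demonstration:

```python
def count_args(*args):
    """Hypothetical helper: report how many positional arguments were passed."""
    return len(args)


items = ["ab", "ac", "ad"]  # made-up items

without_star = count_args(items)   # the whole list arrives as a single argument
with_star = count_args(*items)     # each item arrives as its own argument
print(without_star, with_star)  # 1 3

# So this call...
common = set.intersection(*map(set, items))
# ...is the same as spelling out every argument by hand
spelled_out = set.intersection(set("ab"), set("ac"), set("ad"))
print(common == spelled_out)  # True
```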

The challenge wants the length of this set, and to sum the lengths of all entries, so it's a case of putting this in a loop, and summing the results:

``````total = 0
for entry in data:
    items = entry.split()
    common = set.intersection(*map(set, items))
    total += len(common)
print("total", total)
``````

Or, to code-golf this down to a single line:

``````sum(len(set.intersection(*map(set, entry.split()))) for entry in open("input.txt").read().split("\n\n"))
``````
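The same logic can be wrapped in a function and checked against the two entries used in this post. This is a sketch under a few assumptions: the name `count_common_answers` and the `sample` string are made up, and the guard for empty entries is an addition, since `set.intersection()` raises a `TypeError` when it is unpacked with no arguments at all (for example, if the input were empty).

```python
def count_common_answers(raw: str) -> int:
    """Sum the number of letters shared by every item in each entry."""
    total = 0
    for entry in raw.strip().split("\n\n"):
        items = entry.split()
        if not items:
            # Added guard: set.intersection(*[]) would raise TypeError
            continue
        total += len(set.intersection(*map(set, items)))
    return total


# Hypothetical sample: the "ab/ac" example plus the entry shown earlier
sample = "ab\nac\n\ndonpevkjhymzl\nezyopckdlnvmj\n"
part2_total = count_common_answers(sample)
print(part2_total)  # 1 + 12 = 13
```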

The end.
Onwards!