Solving a Jigsaw puzzle using Neural Nets



This post was originally published by Shiva Verma at Towards Data Science

Solving a 3x3 grid puzzle is extremely difficult, because the number of possible arrangements grows factorially with the number of pieces:

2x2 puzzle = 4! = 24 combinations
3x3 puzzle = 9! = 362,880 combinations

To solve a 3x3 puzzle, the network has to pick the one correct arrangement out of 362,880. This is one more reason why the 3x3 puzzle is so hard.
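The counts above are simply the factorial of the number of pieces. A quick check in Python (the helper name `num_arrangements` is my own):

```python
import math

def num_arrangements(grid_size: int) -> int:
    """Number of ways to order the pieces of a grid_size x grid_size puzzle."""
    return math.factorial(grid_size * grid_size)

for n in (2, 3):
    print(f"{n}x{n} puzzle: {num_arrangements(n)} arrangements")
# → 2x2 puzzle: 24 arrangements
# → 3x3 puzzle: 362880 arrangements
```

A random guess on a 3x3 puzzle is therefore right with probability 1/362,880, versus 1/24 for a 2x2 puzzle.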

Let’s move forward and try to solve a 2x2 Jigsaw puzzle.

No public dataset of jigsaw puzzles was available, so I created one myself, as follows.

  1. Took a raw dataset of around 26K animal images.
  2. Cropped all images to a fixed size of 200x200 pixels.
  3. Split the images into train, test and validation sets.
  4. Cut each image into 4 pieces and randomly rearranged them.
  5. For the training set, I repeated the previous step 4 times to augment the data.
  6. This gave 92K training images and 2K test images. I also set aside 300 images for validation.
  7. The label is an integer array that gives the correct position of each puzzle piece.
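The data-creation steps above can be sketched in plain Python. This is a minimal sketch under my own assumptions, not the original preprocessing code: an image is represented as a list of pixel rows, and the helpers `cut_into_pieces`, `make_puzzle` and `augment` are hypothetical names. The label follows step 7: `label[i]` is the correct position of the piece that ends up in slot `i`.

```python
import random

def cut_into_pieces(image, grid=2):
    """Cut an image (a list of pixel rows) into grid*grid equal pieces, row-major."""
    h, w = len(image), len(image[0])
    ph, pw = h // grid, w // grid
    pieces = []
    for r in range(grid):
        for c in range(grid):
            rows = image[r * ph:(r + 1) * ph]
            pieces.append([row[c * pw:(c + 1) * pw] for row in rows])
    return pieces

def make_puzzle(image, rng, grid=2):
    """Shuffle the pieces; label[i] is the correct position of the piece in slot i."""
    pieces = cut_into_pieces(image, grid)
    label = list(range(grid * grid))
    rng.shuffle(label)
    shuffled = [pieces[pos] for pos in label]
    return shuffled, label

def augment(image, rng, copies=4, grid=2):
    """Step 5: several independently shuffled copies of one training image."""
    return [make_puzzle(image, rng, grid) for _ in range(copies)]

# Toy 4x4 "image" whose pixel values encode their own coordinates.
toy = [[r * 4 + c for c in range(4)] for r in range(4)]
shuffled, label = make_puzzle(toy, random.Random(0))
print(label)
```

With real data you could use NumPy arrays of shape (200, 200, 3) instead of nested lists, and reassemble the shuffled pieces into one 200x200 image to match the input format described below.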

Data Creation Process

This dataset contains both 2x2 and 3x3 puzzles. Below is a data sample of a 2x2 puzzle. The input is a 200x200 pixel image and the label is an array of 4 integers, where each integer gives the correct position of the corresponding piece.

A Data Sample

Our goal is to feed this image into a neural net and get an output which is a vector of 4 integers indicating the correct position of each piece.

After trying more than 20 neural net architectures and a lot of trial and error, I settled on the following design, which worked best.

  • First, extract the 4 puzzle pieces from the image.
  • Then pass each piece through the same CNN, which extracts useful features and outputs a feature vector.
  • Concatenate all 4 feature vectors into one using a Flatten layer.
  • Pass this combined vector through a feed-forward network. The last layer of this network outputs a vector of 16 units.
  • Reshape this 16-unit vector into a 4x4 matrix.

Why do we reshape?

In a normal classification task, a neural network outputs a score for each class. We convert these scores into probabilities by applying a softmax layer, and the class with the highest probability is the predicted class.

The situation is different here. We want to classify each piece into its correct position (0, 1, 2 or 3), and there are 4 such pieces. So we need 4 vectors (one per piece), each holding 4 scores (one per position), which is simply a 4x4 matrix whose rows correspond to pieces and whose columns correspond to positions. Finally, we apply a softmax to each row of this matrix.
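The reshape-and-softmax step described above can be written out in plain Python (the helper names are my own). Each row of the 4x4 matrix is normalized separately, so each piece gets its own probability distribution over the 4 positions.

```python
import math

def reshape_4x4(scores):
    """Turn 16 flat scores into a 4x4 matrix: row = piece, column = position."""
    if len(scores) != 16:
        raise ValueError("expected 16 scores")
    return [scores[i * 4:(i + 1) * 4] for i in range(4)]

def softmax_rows(matrix):
    """Softmax applied independently to each row (one distribution per piece)."""
    result = []
    for row in matrix:
        m = max(row)  # subtract the row max for numerical stability
        exps = [math.exp(v - m) for v in row]
        total = sum(exps)
        result.append([e / total for e in exps])
    return result

probs = softmax_rows(reshape_4x4([float(i) for i in range(16)]))
for row in probs:
    print([round(p, 3) for p in row])
```

In Keras, my understanding is that this corresponds to a `Reshape((4, 4))` layer followed by a softmax activation, which by default normalizes over the last axis, i.e. over positions.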

The following is the network diagram.

Network Design

I am using the Keras framework for this project. Below is the complete network implemented in Keras, which is fairly simple.

Model Implemented in Keras

The input shape is (4, 100, 100, 3), meaning the network receives 4 images (the puzzle pieces), each of shape (100, 100, 3). Cutting a 200x200 image into a 2x2 grid gives exactly four 100x100 pieces.

I am using TimeDistributed (TD) layers. A TD layer applies a given layer to every slice of its input along one extra axis. Here the TD layer applies the same convolutional layer, with the same weights, to each of the 4 input pieces (lines 5, 9, 13, 17).

To use TD layers, the input needs one extra dimension over which the TD layer repeats the wrapped layer. Here that extra dimension is the number of pieces, so we get one feature vector for each of the 4 pieces.
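The idea behind a TimeDistributed wrapper can be sketched without Keras: one function, with one set of weights, is mapped over the extra "pieces" axis. The feature extractor below is a hypothetical stand-in for the CNN (one dot product per weight row), and the toy 2x2 pieces stand in for the real (100, 100, 3) ones.

```python
def time_distributed(layer, inputs):
    """Apply the same `layer` (same weights) to each slice along the first axis."""
    return [layer(x) for x in inputs]

def make_feature_extractor(weight_rows):
    """Hypothetical stand-in for the CNN: maps a piece to a feature vector with
    one dot product per weight row. The weights are shared by all pieces."""
    def extract(piece):
        flat = [v for row in piece for v in row]  # flatten the 2-D piece
        return [sum(w * v for w, v in zip(weights, flat)) for weights in weight_rows]
    return extract

# Four toy 2x2 "pieces" standing in for the (4, 100, 100, 3) input.
pieces = [[[1, 2], [3, 4]], [[0, 0], [0, 1]], [[5, 5], [5, 5]], [[1, 0], [0, 0]]]
cnn = make_feature_extractor([[1, 1, 1, 1], [1, 0, 0, -1]])
features = time_distributed(cnn, pieces)
print(features)  # → [[10, -3], [1, -1], [20, 0], [1, 1]]
```

In Keras, the corresponding pattern is wrapping a layer such as `Conv2D` in `TimeDistributed` and feeding an input of shape (batch, 4, 100, 100, 3); the wrapped layer's weights are shared across the 4 pieces.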

Once the CNN feature extraction is done, we concatenate all the features using the Flatten layer (line 21), pass the resulting vector through a feed-forward network, reshape the final output to a 4x4 matrix and apply a softmax (lines 29, 30).
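Flatten and the final feed-forward layer can be sketched the same way. This is a minimal plain-Python illustration under assumptions: the feature size of 8 and the random weights are hypothetical, and only the last of the 3 feed-forward layers is shown. Flattening a (4, F) block concatenates the pieces' features in piece order, and the last layer maps the result to the 16 scores that are then reshaped to 4x4.

```python
import random

def flatten(feature_vectors):
    """Concatenate the per-piece feature vectors in piece order."""
    return [v for vec in feature_vectors for v in vec]

def dense(x, weights, bias):
    """Fully connected layer without activation: one output per weight row."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, bias)]

rng = random.Random(0)
num_pieces, feat_dim = 4, 8  # feat_dim is a hypothetical feature size
features = [[rng.random() for _ in range(feat_dim)] for _ in range(num_pieces)]
x = flatten(features)  # length 4 * 8 = 32
W = [[rng.uniform(-1, 1) for _ in range(len(x))] for _ in range(16)]
b = [0.0] * 16
scores = dense(x, W, b)  # 16 scores, one per (piece, position) pair
print(len(x), len(scores))  # → 32 16
```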

This task is quite different from a normal classification task. In normal classification, the network tends to focus on the central region of the image, where the object usually is. For a jigsaw, however, the information at the edges of each piece matters much more than the center, because that is where neighboring pieces have to match up.

So my CNN architecture differs from the usual one in the following ways.

Padding

I add some extra padding around each piece before passing it through the CNN (line 3), and I also pad the feature map before each convolution (padding = same) to preserve as much edge information as possible.
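The effect of padding can be made concrete with a little arithmetic. The sketch below (helper names are my own; stride 1 and a 3x3 kernel are assumed for illustration) computes the output size for "valid" versus "same" padding and counts how many kernel windows touch each input position along one row.

```python
import math

def conv_output_size(n, kernel, stride=1, padding="valid"):
    """Spatial output size of a convolution along one axis."""
    if padding == "same":
        return math.ceil(n / stride)        # padded, so only the stride shrinks it
    if padding == "valid":
        return (n - kernel) // stride + 1   # no padding: the border is cropped
    raise ValueError(f"unknown padding: {padding}")

def coverage(n, kernel, pad):
    """How many stride-1 kernel windows include each of the n input positions."""
    counts = [0] * n
    for start in range(-pad, n + pad - kernel + 1):
        for i in range(start, start + kernel):
            if 0 <= i < n:
                counts[i] += 1
    return counts

print(conv_output_size(100, 3), conv_output_size(100, 3, padding="same"))  # → 98 100
print(coverage(6, 3, pad=0))  # → [1, 2, 3, 3, 2, 1]
print(coverage(6, 3, pad=1))  # → [2, 3, 3, 3, 3, 2]
```

Without padding, the border pixels of a piece are seen by a single window and every layer trims a pixel from each side; with one pixel of padding they are seen twice and the feature map keeps its size.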

MaxPooling

I mostly avoid pooling layers, using just one MaxPool layer to reduce the feature map size (line 7). Pooling makes the network (approximately) translation invariant: even if you shift or jiggle the object slightly in the image, the network still detects it. That is good for object classification tasks.

Here, however, we don't want the network to be translation invariant. It should be sensitive to small shifts, because the edge information it relies on depends on exact pixel positions.
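A tiny 1-D example shows why pooling hides small shifts. It illustrates the general idea and is not part of the model: after a 2-wide max pool, a feature at index 1 and the same feature moved to index 0 produce identical outputs, so the exact position is lost.

```python
def max_pool_1d(values, size=2):
    """Non-overlapping 1-D max pooling (stride equal to the window size)."""
    return [max(values[i:i + size]) for i in range(0, len(values) - size + 1, size)]

original = [0, 9, 0, 0, 0, 0]  # a strong "edge" response at index 1
shifted = [9, 0, 0, 0, 0, 0]   # the same response shifted left by one pixel
print(max_pool_1d(original))   # → [9, 0, 0]
print(max_pool_1d(shifted))    # → [9, 0, 0]
```

The invariance is only local: moving the response across a pooling-window boundary (for example to index 2) does change the output.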

Shallow Network

The early layers of a CNN extract low-level features such as edges and corners. Deeper layers tend to extract higher-level features such as shapes and color distribution, which are not very relevant in our case, so a shallow network helps here.

These are the important details of the CNN architecture. The rest of the network is fairly simple: 3 feed-forward layers, a reshape layer and finally a softmax layer.

Finally, I compile the model with the sparse_categorical_crossentropy loss and the Adam optimizer. The target is a 4-unit vector giving the correct position of each piece.

Target Vector: [[3],[0],[1],[2]]
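With a 4x4 softmax output and a target like the one above, sparse categorical cross-entropy is, for each piece, the negative log of the probability assigned to that piece's correct position. The sketch below averages over the 4 pieces, which is my understanding of how Keras reduces it by default; treat that averaging (and the helper name `sparse_cce`) as assumptions.

```python
import math

def sparse_cce(probs, target):
    """Mean over pieces of -log(probability given to the correct position)."""
    losses = [-math.log(row[t[0]]) for row, t in zip(probs, target)]
    return sum(losses) / len(losses)

target = [[3], [0], [1], [2]]
confident = [[0.0, 0.0, 0.0, 1.0], [1.0, 0.0, 0.0, 0.0],
             [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]]
uniform = [[0.25] * 4 for _ in range(4)]
print(sparse_cce(confident, target))  # zero loss for a perfect prediction
print(sparse_cce(uniform, target))    # log(4), about 1.386, for a blind guess
```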

I trained the network for 5 epochs, starting with a learning rate of 0.001 and a batch size of 64. After each epoch, I reduced the learning rate and increased the batch size.
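The post does not say by how much the learning rate and batch size change each epoch. The sketch below assumes, purely for illustration, that the learning rate is halved and the batch size doubled after every epoch; `lr_decay` and `batch_growth` are hypothetical parameters, not the author's values.

```python
def training_schedule(epochs=5, lr=0.001, batch_size=64, lr_decay=0.5, batch_growth=2):
    """Per-epoch (epoch, learning_rate, batch_size) tuples; decay factors are assumed."""
    schedule = []
    for epoch in range(epochs):
        schedule.append((epoch, lr, batch_size))
        lr *= lr_decay            # shrink the step size after each epoch
        batch_size *= batch_growth  # and train on larger batches
    return schedule

for epoch, lr, bs in training_schedule():
    print(f"epoch {epoch}: lr={lr:g}, batch_size={bs}")
```

Each tuple could then drive one epoch of training, for example one `model.fit(..., epochs=1, batch_size=bs)` call after updating the optimizer's learning rate.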

At prediction time, the network outputs a 4x4 matrix. We select the index of the maximum value in each row, which is the predicted position of that piece, giving a vector of length 4. Using this vector, we can also rearrange the puzzle pieces and visualize the result.
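The decoding step above can be sketched as follows (helper names are my own). `rearrange` puts each piece into the slot predicted for it, which inverts the shuffle when the predictions are correct.

```python
def predict_positions(probs):
    """Row-wise argmax: the predicted position of each piece."""
    return [max(range(len(row)), key=row.__getitem__) for row in probs]

def is_permutation(positions):
    """True if every position is used exactly once."""
    return sorted(positions) == list(range(len(positions)))

def rearrange(pieces, positions):
    """Place each piece in the slot the model predicted for it."""
    solved = [None] * len(pieces)
    for piece, pos in zip(pieces, positions):
        solved[pos] = piece
    return solved

probs = [[0.1, 0.1, 0.1, 0.7],
         [0.6, 0.2, 0.1, 0.1],
         [0.1, 0.8, 0.05, 0.05],
         [0.2, 0.1, 0.6, 0.1]]
positions = predict_positions(probs)
print(positions)                                     # → [3, 0, 1, 2]
print(rearrange(["D", "A", "B", "C"], positions))    # → ['A', 'B', 'C', 'D']
```

Note that a row-wise argmax does not guarantee a valid arrangement: two pieces can be assigned the same slot, which `is_permutation` detects. One possible alternative (not something the post uses) is to treat the matrix as an assignment problem, for example with `scipy.optimize.linear_sum_assignment` on the negative log-probabilities.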

After training, I ran the model on 2K unseen puzzles, and it solved 80% of them correctly, which is quite fair.
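I am assuming a puzzle counts as solved only when all four predicted positions match the label. The sketch below contrasts that strict puzzle-level accuracy with a per-piece accuracy, which can never be lower; the function names and toy predictions are hypothetical and are not results from the post.

```python
def puzzle_accuracy(predicted, targets):
    """Fraction of puzzles where every piece is in the right position."""
    solved = sum(1 for p, t in zip(predicted, targets) if list(p) == list(t))
    return solved / len(targets)

def piece_accuracy(predicted, targets):
    """Fraction of individual pieces placed correctly."""
    correct = sum(pi == ti for p, t in zip(predicted, targets) for pi, ti in zip(p, t))
    return correct / sum(len(t) for t in targets)

predicted = [[3, 0, 1, 2], [0, 1, 2, 3], [1, 0, 2, 3]]
targets = [[3, 0, 1, 2], [0, 1, 3, 2], [1, 0, 2, 3]]
print(puzzle_accuracy(predicted, targets))  # 2 of 3 puzzles fully solved
print(piece_accuracy(predicted, targets))   # 10 of 12 pieces correct
```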
