Crop the center of the image in Keras ImageDataGenerator or flow_from_directory

I am trying to crop the center of the images produced by a Keras ImageDataGenerator. My images are 192x192 and I want to crop their centers so that the output batches are 150x150 or something similar.

Can I do this directly in Keras ImageDataGenerator? I guess not, since, as far as I can see, the target_size argument squashes (resizes) the images instead of cropping them.

I found this link for random cropping: https://jkjung-avt.github.io/keras-image-cropping/

I have already modified the crop function from that post to take the center of the image instead:

import numpy as np
from keras import backend as K

def my_crop(img, crop_size):
    # Return the central crop of a single image; crop_size is (height, width).
    dy, dx = crop_size
    if K.image_data_format() == 'channels_last':
        assert img.shape[2] == 3
        height, width = img.shape[0], img.shape[1]
        start_y = (height - dy) // 2
        start_x = (width - dx) // 2
        return img[start_y:start_y + dy, start_x:start_x + dx, :]
    else:
        # 'channels_first': the image shape is (channels, height, width)
        assert img.shape[0] == 3
        height, width = img.shape[1], img.shape[2]
        start_y = (height - dy) // 2
        start_x = (width - dx) // 2
        return img[:, start_y:start_y + dy, start_x:start_x + dx]

def crop_generator(batches, crop_length):
    '''
    Take as input a Keras ImageGen (Iterator) and generate
    crops from the image batches generated by the original iterator
    '''
    while True:
        batch_x, batch_y = next(batches)
        if K.image_data_format() == 'channels_last':
            batch_crops = np.zeros((batch_x.shape[0], crop_length, crop_length, 3))
        else:
            batch_crops = np.zeros((batch_x.shape[0], 3, crop_length, crop_length))
        for i in range(batch_x.shape[0]):
            batch_crops[i] = my_crop(batch_x[i], (crop_length, crop_length))
        yield (batch_crops, batch_y)

This solution seems very slow to me. Is there a more efficient way? What would you suggest?

Thanks in advance.

Ksm Kls asked May 23 '18 10:05


1 Answer

I tried to solve it this way, slicing the whole batch at once instead of cropping image by image:

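def crop_generator(batches, crop_length):
    while True:
        batch_x, batch_y = next(batches)
        # Read the image size from the batch; axis order depends on the data format.
        if K.image_data_format() == 'channels_last':
            img_height, img_width = batch_x.shape[1], batch_x.shape[2]
        else:
            img_height, img_width = batch_x.shape[2], batch_x.shape[3]
        start_y = (img_height - crop_length) // 2
        start_x = (img_width - crop_length) // 2
        end_y, end_x = start_y + crop_length, start_x + crop_length
        # Height (rows) comes before width (columns) in both layouts.
        if K.image_data_format() == 'channels_last':
            batch_crops = batch_x[:, start_y:end_y, start_x:end_x, :]
        else:
            batch_crops = batch_x[:, :, start_y:end_y, start_x:end_x]
        yield (batch_crops, batch_y)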
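
The slice arithmetic can be checked without Keras on plain NumPy arrays. This is only a sketch under assumptions: center_crop_batch, crop_batches and the channels_last flag are hypothetical names introduced for illustration (the flag stands in for K.image_data_format()), and the fake batches replace a real flow_from_directory iterator.

```python
import numpy as np

def center_crop_batch(batch, crop_length, channels_last=True):
    """Center-crop a 4D batch to crop_length x crop_length (hypothetical helper)."""
    # Pick the height and width axes for the assumed data layout.
    h_axis, w_axis = (1, 2) if channels_last else (2, 3)
    height, width = batch.shape[h_axis], batch.shape[w_axis]
    start_y = (height - crop_length) // 2
    start_x = (width - crop_length) // 2
    # Build one index that keeps every axis whole except height and width.
    index = [slice(None)] * 4
    index[h_axis] = slice(start_y, start_y + crop_length)
    index[w_axis] = slice(start_x, start_x + crop_length)
    return batch[tuple(index)]

def crop_batches(batches, crop_length, channels_last=True):
    """Wrap any iterator of (batch_x, batch_y) pairs and yield cropped batches."""
    for batch_x, batch_y in batches:
        yield center_crop_batch(batch_x, crop_length, channels_last), batch_y

# Fake data standing in for a Keras iterator: two batches of 4 images, 192x192x3.
fake_batches = [(np.random.rand(4, 192, 192, 3), np.zeros(4)) for _ in range(2)]
cropped_x, cropped_y = next(crop_batches(iter(fake_batches), 150))
print(cropped_x.shape)  # (4, 150, 150, 3)
```

Basic slicing returns a view, so no per-image copy is made, which is where the speed-up over filling a preallocated array in a loop comes from. Computing the end index as start + crop_length keeps the output size exact even when the size difference is odd. Unlike the while True loop above, this for loop stops when the wrapped iterator is exhausted, which only matters for finite iterators like the fake one here.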

Still, if you have a better way, please share your suggestions.

Ksm Kls answered Sep 20 '22 15:09