TensorFlow for cropping and rendering
In tensorflow, there are many useful features for tasks beyond just learning. One worth mentioning is the built-in function for rendering framed rectangles for given bounding boxes, which can be used for a fast visual check of the predictions on an image. But what if we want to get fancy? We have recently faced the task of cropping given areas from an image, processing them into features, and rendering them back onto the image. We wanted to do this in keras or tensorflow so that it is fast and can be embedded in a model. Or possibly even overlay an image not with framed bounding boxes, but alpha-blend it with filled rectangles. To be clear, we want feature vector rendering, not a full graphics renderer in tensorflow, which already exists (at least in the form of this example).
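For reference, here is a minimal sketch of that built-in box renderer (we use the TF 1.x API throughout this post); it can only draw frames, which is exactly the limitation we want to get around:

import tensorflow as tf
import numpy as np

# boxes have shape [batch, num_boxes, 4], in relative [y_min, x_min, y_max, x_max] coordinates
images_ph = tf.placeholder(tf.float32, [None, None, None, 3])
boxes = tf.constant([[[0.1, 0.1, 0.5, 0.5]]])
framed = tf.image.draw_bounding_boxes(images_ph, boxes)

with tf.Session() as sess:
    out = sess.run(framed, feed_dict={images_ph: np.zeros((1, 64, 64, 3))})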
Here we will introduce the concepts along with the code; we have also published the code in a form that should run in an ipython notebook.
1) Cropping
There is a nice function, tf.image.crop_and_resize, in tensorflow that does the cropping almost exactly the way we want. It needs a bit of care, though, because it expects a flat list of boxes together with an index that assigns each box to an image in the provided batch, while our original intent was to keep the bboxes grouped per batch item.
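To see what we need to work around, here is a minimal sketch of the raw call (the shapes and box values are illustrative):

import tensorflow as tf
import numpy as np

images = tf.placeholder(tf.float32, [None, None, None, 1])  # [batch, height, width, depth]
crops = tf.image.crop_and_resize(
    images,
    boxes=[[0.0, 0.0, 0.5, 0.5], [0.25, 0.25, 0.75, 0.75]],  # relative [y1, x1, y2, x2]
    box_ind=[0, 0],  # both boxes are taken from the first image of the batch
    crop_size=(10, 10))

with tf.Session() as sess:
    result = sess.run(crops, feed_dict={images: np.ones((2, 100, 100, 1))})
    print(result.shape)  # (2, 10, 10, 1) -- a flat list of crops, not grouped per batch item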
Below is the full code, along with a simple test and an example of how to use the function as a keras layer (we are using it that way 🙂).
import itertools

import tensorflow as tf
import numpy as np
from keras.layers import Lambda


def tf_crop_and_resize_batches_of_arrays(image_input, boxes_input, crop_size=(10, 10)):
    """
    The dimension of boxes is [batch, maxnumperbatch, 4]  # some can be just a bogus-zero-fill
    The dimension of the output should be [batch, maxnumperbatch, crop_0, crop_1, depth] crops.
    """
    bboxes_per_batch = tf.shape(boxes_input)[1]
    batch_size = tf.shape(boxes_input)[0]  # should be the same as image_input.shape[0]
    # the goal is to create a [batch, maxnumperbatch] field of values,
    # which are the same across batch and equal to the batch_id,
    # and then to reshape it in the same way as we reshape the boxes_input,
    # to tell tf about each bbox's batch (and image).
    index_to_batch = tf.tile(tf.expand_dims(tf.range(batch_size), -1), (1, bboxes_per_batch))
    # now both get reshaped as tf wants it:
    boxes_processed = tf.reshape(boxes_input, (-1, 4))
    box_ind_processed = tf.reshape(index_to_batch, (-1,))
    # the method wants boxes = [num_boxes, 4], box_ind = [num_boxes] to index into the batch
    # the method returns [num_boxes, crop_height, crop_width, depth]
    tf_produced_crops = tf.image.crop_and_resize(
        image_input,
        boxes_processed,
        box_ind_processed,
        crop_size,
        method='bilinear',
        extrapolation_value=0,
        name=None
    )
    new_shape = tf.concat([tf.stack([batch_size, bboxes_per_batch]),
                           tf.shape(tf_produced_crops)[1:]], axis=0)
    crops_resized_to_original = tf.reshape(tf_produced_crops, new_shape)
    return crops_resized_to_original


def keras_crop_and_resize_batches_of_arrays(image_input, boxes_input, crop_size=(10, 10)):
    """
    A helper function for tf_crop_and_resize_batches_of_arrays, assuming that
    crop_size is a constant and not a tensorflow operation.
    """
    def f_crop(packed):
        image, boxes = packed
        return tf_crop_and_resize_batches_of_arrays(image, boxes, crop_size)
    return Lambda(f_crop)([image_input, boxes_input])


def test_crops():
    # the intended usage:
    crop_size = (10, 10)
    image_input = np.ones((2, 200, 200, 1))
    boxes_input = np.array([[[0.0, 0.0, 0.5, 0.5],
                             [0.0, 0.0, 0.5, 0.5],
                             [0.0, 0.0, 1.0, 1.0]],
                            [[0.0, 0.0, 0.5, 0.5],
                             [0.0, 0.0, 0.5, 0.5],
                             [0.0, 0.0, 1.0, 1.0]]])
    # the dimension of boxes is [batch, maxnumperbatch, 4]
    # some can be just a bogus-zero-fill
    with tf.Session() as sess:
        image_input_ph = tf.placeholder(tf.float32, [None, None, None, None])
        boxes_input_ph = tf.placeholder(tf.float32, [None, None, 4])
        crop_result = sess.run(
            tf_crop_and_resize_batches_of_arrays(image_input_ph, boxes_input_ph, crop_size=crop_size),
            feed_dict={image_input_ph: image_input, boxes_input_ph: boxes_input})
        assert np.all(crop_result == 1.0), "when cropping an image full of ones, we should get all ones too!"
        assert crop_result.shape == (boxes_input.shape[0], boxes_input.shape[1],
                                     crop_size[0], crop_size[1], image_input.shape[-1])
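And a quick sketch of how the keras wrapper can be plugged into a model (the Input shapes here are illustrative assumptions, not part of the code above):

from keras.layers import Input
from keras.models import Model

image_in = Input(shape=(None, None, 1))   # [batch, height, width, depth]
boxes_in = Input(shape=(None, 4))         # [batch, maxnumperbatch, 4]
crops = keras_crop_and_resize_batches_of_arrays(image_in, boxes_in, crop_size=(10, 10))
model = Model(inputs=[image_in, boxes_in], outputs=crops)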
2) Rendering rectangles
In numpy, we would simply slice the array and assign the values. In tensorflow, however, assignment is not a valid operation on tensors.
So what options do we have?
- modify variables (see the sketch after this list)
- wrap the numpy code in a python function and make it a tensorflow op (tf.py_func)
- use custom kernels (which can be the fastest way, but will not be covered here)
- we can process each bbox separately and then sum the results over the rendering plane
- we can process all bboxes at each pixel and ask if they belong there
- we can use the sparse_to_dense operation, or scatter_nd, which should be faster.
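To illustrate the first option: slice-like assignment does work on variables (as opposed to plain tensors), e.g. via tf.scatter_nd_update. A minimal sketch with illustrative shapes:

import tensorflow as tf

canvas = tf.Variable(tf.zeros([5, 5]))
# paint the pixel block [1:3, 1:3] by enumerating its coordinates explicitly
indices = [[y, x] for y in range(1, 3) for x in range(1, 3)]
painted = tf.scatter_nd_update(canvas, indices, tf.ones([len(indices)]))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(painted))  # a 5x5 canvas with ones in the 2x2 block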
How would this work? We could produce lists of the points that lie inside the rectangles, but that is not easy to parallelize, because each bbox contains a different number of pixels.
Fortunately, we are also fans of computer graphics, so we know about spreading filters (more in our beloved colleague’s thesis and in the sources below [1][2]), which use, for nearly the same goal, a cumulative sum operation that is available in tensorflow. We will explain more in a moment.
Let’s imagine a 1d case with a 1d image, where we need to fill intervals. We cannot tell tensorflow directly to start at a specified left corner and keep filling until the specified right corner is found. But if we mark all left corners with +1 and all right corners with -1 and then run a cumulative sum (from the left), then at each pixel we know how many intervals cover it. We can mark the edges/corners using scatter_nd. Better still, gathering the data for this function parallelizes easily, because the number of edges or corners is, unlike the area, the same for every box.
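A minimal 1d sketch of this trick (the interval bounds are illustrative): fill the intervals [2, 5) and [4, 8) on a line of 10 pixels.

import tensorflow as tf

starts = tf.constant([[2], [4]])  # left corners get +1
ends = tf.constant([[5], [8]])    # right corners get -1
indices = tf.concat([starts, ends], axis=0)
updates = tf.concat([tf.ones([2], tf.int32), -tf.ones([2], tf.int32)], axis=0)
marks = tf.scatter_nd(indices, updates, shape=[10])  # [0 0 1 0 1 -1 0 0 -1 0]
coverage = tf.cumsum(marks)                          # [0 0 1 1 2 1 1 1 0 0]

with tf.Session() as sess:
    print(sess.run(coverage))  # each pixel holds the number of intervals covering it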
How can we do it in 2d (or higher dimensions)? Step by step, dimension after dimension: if we first generate the data for the top and bottom edges, we can then run a cumulative sum from top to bottom and the whole box gets filled. We only need to remember that the second phase of the process needs pluses at the top and minuses at the bottom, so that filling starts at the top and ends at the bottom, as the ascii-art picture shows:
...........     ...........      ...........
....+..-...     ....++++...      ....++++...
........... ->  ...........  ->  ....++++...
....-..+...     ....----...      ....++++...
...........     ...........      ...........

  original      after cumsums:
                (left to right)  (top to bottom)
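The same two-pass trick, sketched in plain numpy for clarity (a 5x5 canvas with one box, chosen for illustration):

import numpy as np

canvas = np.zeros((5, 5), dtype=np.int32)
t, l, b, r = 1, 1, 4, 4   # fill rows 1..3 and columns 1..3
canvas[t, l] += 1          # corners with an even number of 'far' coordinates get a plus
canvas[b, r] += 1
canvas[t, r] -= 1          # the other two corners get a minus
canvas[b, l] -= 1
filled = canvas.cumsum(axis=1).cumsum(axis=0)
print(filled)              # ones exactly inside the box, zeros elsewhere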
To produce the pluses and minuses as we want them, we iterate over the corners in python using itertools.product and alter the signs so that the right +/- is produced: we count the number of ones in the corner index and assign plus to even counts and minus to odd ones. Notice that this design choice makes the function able to produce boxes in any dimension. The dimension is fixed before the tensorflow graph is generated (at 'compile time'), which should be sufficient. A test function is provided in a simple form for 2d only. The whole code is presented below, together with a python operation and a test.
from timeit import default_timer as timer


def render_bboxes_pyfunc_2d(elems, target_shape):
    """
    2d only numpy + tf.py_func replacement for render_nd_bboxes_tf_spreading.
    For testing purposes.
    """
    # target_shape = [dimx, dimy, ...]
    def py_render_boxes_2d(x_boxes_data, out_shape):
        # x will be a numpy array with the contents of the placeholder below
        if len(x_boxes_data.shape) <= 2:
            result = np.zeros(list(out_shape) + [x_boxes_data.shape[-1] - 2 * 2], dtype=np.float32)
            for box in x_boxes_data:
                result[box[0]:box[2], box[1]:box[3], :] += box[4:]
        else:  # also a batch dimension is provided
            result = np.zeros([x_boxes_data.shape[0]] + list(out_shape) + [x_boxes_data.shape[-1] - 2 * 2],
                              dtype=np.float32)
            for i, batch in enumerate(x_boxes_data):
                for box in batch:
                    result[i, box[0]:box[2], box[1]:box[3], :] += box[4:]
        return result
    return tf.py_func(py_render_boxes_2d, [elems, target_shape], tf.float32)


def render_nd_bboxes_tf_spreading(elems, target_shape, ndim=2):
    """
    elems: tensor of size [..., n_boxes, 2*ndim + val_dim], where in the last dimension,
        there are packed edge coordinates and values (of val_dim) to be filled in the specified box.
    target_shape: list/tuple of ndim entries.
    returns: rendered image of size [elems(...), target_shape..., val_dim]
        ('elems(...)' usually means batch_size)
    """
    assert_shape_ndim = tf.Assert(tf.equal(tf.size(target_shape), ndim), [target_shape])
    assert_nonempty_data = tf.Assert(tf.greater(tf.shape(elems)[-1], 2 * ndim), [elems])
    with tf.control_dependencies([assert_shape_ndim, assert_nonempty_data]):
        '''
        In 3d there must be another wall of minuses, looking like that:
        -   +
        .....
        +   -
        so when indexing [0, 1] to ltrb... pluses are where there is an even number
        of 0s in the corner index, minuses where odd.
        '''
        el_ndim = len(elems.shape)
        # we do not access this property in tensorflow runtime, but in 'compile time',
        # because, well, the number of dimensions should be known beforehand
        assert el_ndim >= 2 and el_ndim <= 3, \
            "elements should be in the form of [batch, n, coordinates] or [n, coordinates]"
        if el_ndim == 3:
            # we use the batch_size dimension also!
            bboxes_per_batch = tf.shape(elems)[1]
            batch_size = tf.shape(elems)[0]  # should be the same as image_input.shape[0]
            index_to_batch = tf.tile(tf.expand_dims(tf.range(batch_size), -1), (1, bboxes_per_batch))
            index_to_batch = tf.reshape(index_to_batch, (-1, 1))
        else:
            index_to_batch = None

        val_vector_size = tf.shape(elems)[-1] - 2 * ndim
        corner_ids = list(itertools.product([0, 1], repeat=ndim))
        corners_lists = []
        corners_values = []
        for corner in corner_ids:
            plus = sum(corner) % 2 == 0
            id_from_corner = [i + ndim * c for i, c in enumerate(corner)]
            # indexes a corner into [left, top, right, bottom] notation
            corner_coord = tf.gather(elems[..., 0:2 * ndim], id_from_corner, axis=-1)
            corner_value = elems[..., 2 * ndim:] * (1 if plus else -1)  # last dimension is == val_vector_size
            if index_to_batch is not None:
                # if the operation is called in batches, remember to reshape it all
                # into one long list for scatter_nd and add (concatenate) the batch ids
                corner_coord = tf.concat([index_to_batch, tf.reshape(corner_coord, (-1, ndim))], axis=-1)
                corner_value = tf.reshape(corner_value, (-1, val_vector_size))
            corners_lists.append(corner_coord)
            corners_values.append(corner_value)

        indices = tf.concat(corners_lists, axis=0)
        updates = tf.concat(corners_values, axis=0)
        shape = tf.concat([tf.shape(elems)[:-2], target_shape, [val_vector_size]], axis=0)
        dense_orig = tf.scatter_nd(indices, updates, shape=shape)
        dense = dense_orig
        for dim in range(ndim):
            # we want to start from the axis before the last one. The last one is the value
            # dimension, and the first dimensions might be the batched dimensions
            dense = tf.cumsum(dense, axis=-2 - dim, exclusive=False, reverse=False, name=None)
        return dense


def test_render_bboxes_2d():
    # test without the batch_size dimension
    elems = [[0, 1, 3, 10, 1],
             [1, 13, 4, 17, 1],
             [0, 1, 3, 10, 1]] + [[8, 9, 15, 15, 1]] * 1000
    target_shape = [20, 20]
    with tf.Session() as sess:
        elems_ph = tf.placeholder(tf.int32, [None, None])
        shape_ph = tf.placeholder(tf.int32, [None])
        start_np = timer()
        np_result = sess.run(render_bboxes_pyfunc_2d(elems_ph, shape_ph),
                             feed_dict={elems_ph: elems, shape_ph: target_shape})
        end_np = timer()

        elems_ph2 = tf.placeholder(tf.int32, [None, None])
        shape_ph2 = tf.placeholder(tf.int32, [None])
        start_tf = timer()
        tf_result = sess.run(render_nd_bboxes_tf_spreading(elems_ph2, shape_ph2, ndim=2),
                             feed_dict={elems_ph2: elems, shape_ph2: target_shape})
        end_tf = timer()

        assert np.all(np.equal(np_result, tf_result))
        print((end_np - start_np, end_tf - start_tf))


def test_render_bboxes_batch_2d():
    # test with the batch_size dimension (being 2)
    elems = [[[0, 1, 3, 10, 1], [1, 13, 4, 17, 1], [0, 1, 3, 10, 1]],
             [[0, 2, 3, 10, 1], [1, 14, 4, 17, 1], [0, 2, 3, 10, 1]]]
    target_shape = [20, 20]
    with tf.Session() as sess:
        elems_ph = tf.placeholder(tf.int32, [None, None, None])
        shape_ph = tf.placeholder(tf.int32, [None])
        start_np = timer()
        np_result = sess.run(render_bboxes_pyfunc_2d(elems_ph, shape_ph),
                             feed_dict={elems_ph: elems, shape_ph: target_shape})
        end_np = timer()

        elems_ph2 = tf.placeholder(tf.int32, [None, None, None])
        shape_ph2 = tf.placeholder(tf.int32, [None])
        start_tf = timer()
        tf_result = sess.run(render_nd_bboxes_tf_spreading(elems_ph2, shape_ph2, ndim=2),
                             feed_dict={elems_ph2: elems, shape_ph2: target_shape})
        end_tf = timer()

        assert tf_result.ndim == 4, "bboxes should be able to be rendered into batches of images"
        assert tf_result.shape[-1] == len(elems[0][0]) - 2 * 2
        assert np_result.shape[0] == tf_result.shape[0] == len(elems), \
            "we have provided a different number of batches"
        assert np.all(np.equal(np_result, tf_result))
        print((end_np - start_np, end_tf - start_tf))
Conclusion and visualizations
Note that we have created two functions that need differently scaled inputs (the cropping one takes relative float coordinates, the rendering one integer pixel coordinates), because we preserved the formats set by the original tensorflow functions. And for the final visualization, here is the code that builds the computational graph, loads a picture, displays the crops, processes them (as a toy case, it multiplies the original image with the rendered mask) and outputs the pictures:
from matplotlib import pyplot as plt
from PIL import Image
import requests


def test_with_lena():
    pic = np.asarray(Image.open(requests.get(
        "https://upload.wikimedia.org/wikipedia/en/7/7d/Lenna_%28test_image%29.png",
        stream=True).raw))
    boxes_input = np.array([[[100.0 / pic.shape[0], 100.0 / pic.shape[1],
                              200.0 / pic.shape[0], 200.0 / pic.shape[1]]]])
    crop_size = (800, 800)
    with tf.Session() as sess:
        image_input_ph = tf.placeholder(tf.float32, [None, None, None, None])
        boxes_input_ph = tf.placeholder(tf.float32, [None, None, 4])
        # scale the relative box coordinates to pixels and append a value of 1 to each box:
        boxes_with_ones = tf.to_int32(
            tf.concat([tf.multiply(boxes_input_ph,
                                   tf.to_float(tf.concat([tf.shape(image_input_ph)[1:3],
                                                          tf.shape(image_input_ph)[1:3]], axis=-1))),
                       tf.ones(tf.concat([tf.shape(boxes_input_ph)[0:2], [1]], axis=-1))],
                      axis=-1))
        shape_ph = tf.shape(image_input_ph)[1:3]
        crop_result, render_result, only_box = sess.run(
            [tf_crop_and_resize_batches_of_arrays(image_input_ph, boxes_input_ph, crop_size=crop_size),
             tf.multiply(tf.to_float(render_nd_bboxes_tf_spreading(boxes_with_ones, shape_ph, ndim=2)),
                         image_input_ph),
             tf.tile(tf.to_float(render_nd_bboxes_tf_spreading(boxes_with_ones, shape_ph, ndim=2)),
                     [1, 1, 1, 3])],
            feed_dict={image_input_ph: np.expand_dims(pic, 0), boxes_input_ph: boxes_input})

    fig = plt.figure(figsize=(24, 32))
    plt.imshow(np.asarray(pic).astype(dtype='B'))
    plt.tick_params(left='off', bottom='off', labelleft='off', labelbottom='off')
    l, t, r, b = [100, 100, 200, 200]
    plt.gca().add_patch(plt.Rectangle((l, b), r - l, t - b, color='b', fill=False))
    plt.savefig('pic.png')

    fig = plt.figure(figsize=(24, 32))
    plt.imshow(np.asarray(crop_result[0, 0, :, :, :]).astype(dtype='B'))
    plt.tick_params(left='off', bottom='off', labelleft='off', labelbottom='off')
    plt.savefig('crop.png')

    fig = plt.figure(figsize=(24, 32))
    plt.imshow(np.asarray(render_result[0, :, :, :]).astype(dtype='B'))
    plt.tick_params(left='off', bottom='off', labelleft='off', labelbottom='off')
    plt.savefig('render.png')

    fig = plt.figure(figsize=(24, 32))
    plt.imshow(np.asarray(only_box[0, :, :, :] * 255).astype(dtype='B'))
    plt.tick_params(left='off', bottom='off', labelleft='off', labelbottom='off')
    plt.savefig('only.png')

Sources:
[1] Kosloff, T. J. Fast Image Filters for Depth of Field Post-Processing. PhD thesis, EECS Department, University of California, Berkeley, May 2010.
[2] Kosloff, T. J., Hensley, J., and Barsky, B. A. Fast filter spreading and its applications. Tech. Rep. UCB/EECS-2009-54, EECS Department, University of California, Berkeley, Apr 2009.