import tensorflow as tf
import tensorflow_hub as hub
import matplotlib.pyplot as plt
import tempfile
from six.moves.urllib.request import urlopen
from six import BytesIO
import numpy as np
from PIL import Image
from PIL import ImageOps
from PIL import ImageFont
from PIL import ImageDraw
from PIL import ImageColor
import time
# TF Hub handle of the detection module used below (the handle path names
# Open Images v4 / SSD / MobileNet v2).
module_handle = "https://tfhub.dev/google/openimages_v4/ssd/mobilenet_v2/1"
model = hub.load(module_handle)
# Notebook output captured when the module was loaded:
# INFO:tensorflow:Saver not created because there are no variables in the graph to restore
# INFO:tensorflow:Saver not created because there are no variables in the graph to restore

# Lists the signatures exposed by the loaded module. As a bare expression this
# only shows a value in an interactive session.
model.signatures.keys()
# Notebook output:
# KeysView(_SignatureMap({'default': <ConcreteFunction pruned(images) at 0x280801A4D30>}))

# The 'default' signature is the callable that run_detector() feeds images to.
detector = model.signatures['default']
def display_image(image):
    """Show `image` in a new 20x15 matplotlib figure with the grid disabled."""
    figure_size = (20, 15)
    # Creating the figure makes it current, so the pyplot calls below draw
    # into it.
    plt.figure(figsize=figure_size)
    plt.grid(False)
    plt.imshow(image)
def load_and_process_image(url, new_width=256, new_height=256, flag=False):
    """Download an image, fit it to the requested size and save it as a JPEG.

    Args:
        url: Address of the image to download.
        new_width: Width, in pixels, of the saved image.
        new_height: Height, in pixels, of the saved image.
        flag: When true, the processed image is also shown via display_image().

    Returns:
        Path of the temporary ``.jpg`` file that holds the processed image.
        The file is not removed automatically; deleting it is up to the
        caller.
    """
    # Always close the HTTP response, even if reading it fails.
    response = urlopen(url)
    try:
        raw_bytes = response.read()
    finally:
        response.close()

    image = Image.open(BytesIO(raw_bytes))
    # Image.ANTIALIAS was removed in Pillow 10; Image.LANCZOS names the same
    # resampling filter and also exists in older releases.
    image = ImageOps.fit(image, (new_width, new_height), Image.LANCZOS)
    image = image.convert('RGB')

    # Reserve the output path only once there is something to write, and
    # close the handle immediately: tempfile.mkstemp() would hand back an
    # open OS-level descriptor that has to be closed by hand.
    with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmp_file:
        filename = tmp_file.name
    image.save(filename, format='JPEG', quality=90)

    if flag:
        # display() only exists as an IPython builtin; display_image() is the
        # helper defined in this file and also works in a plain script.
        display_image(image)
    return filename
# Sample photo to run the detector on, and the size it is fitted to before
# being saved locally.
image_url = "https://upload.wikimedia.org/wikipedia/commons/6/60/Naxos_Taverna.jpg"
image_width = 1280
image_height = 856
# Download, resize and display the image; image_path is the local JPEG copy.
image_path = load_and_process_image(
    image_url,
    new_width=image_width,
    new_height=image_height,
    flag=True,
)
def _text_size(font, text):
    """Return the (width, height) that `text` occupies when drawn with `font`."""
    if hasattr(font, "getbbox"):
        # getbbox() returns (left, top, right, bottom) relative to the text
        # origin; a rectangle of size (right, bottom) anchored at the origin
        # therefore covers the rendered glyphs.
        _, _, right, bottom = font.getbbox(text)
        return right, bottom
    # Older Pillow releases only offer getsize() (removed in Pillow 10).
    return font.getsize(text)


def draw_bb_on_image(image, xmin, ymin, xmax, ymax, color, font,
                     thickness, display_str_list=()):
    """Draw one bounding box, and optionally its label, onto a PIL image.

    The image is modified in place; nothing is returned.

    Args:
        image: PIL image to draw on.
        xmin, ymin, xmax, ymax: Box corners as fractions of the image width
            (x) and height (y); they are scaled by ``image.size`` here.
        color: Colour used for the box outline and the label background.
        font: PIL font used to measure and render the label.
        thickness: Width of the box outline.
        display_str_list: Label strings. Only the first entry is drawn; when
            the sequence is empty just the box outline is drawn.
    """
    draw = ImageDraw.Draw(image)
    image_width, image_height = image.size
    left, right = xmin * image_width, xmax * image_width
    top, bottom = ymin * image_height, ymax * image_height
    draw.line([(left, top), (right, top), (right, bottom),
               (left, bottom), (left, top)],
              width=thickness, fill=color)

    # Nothing to label: the default argument is an empty tuple, so indexing
    # it unconditionally would raise IndexError.
    if len(display_str_list) == 0:
        return

    class_name = display_str_list[0]
    string_width, string_height = _text_size(font, class_name)

    # Vertical room the label needs, including a 5% margin above and below.
    margin_ratio = 0.05
    total_string_height = (1 + 2 * margin_ratio) * string_height

    # Put the label above the box when it fits inside the image, otherwise
    # just below the top edge of the box.
    if top > total_string_height:
        string_bottom = top
    else:
        string_bottom = top + total_string_height

    # Rounding the ratio itself up always yields a padding of 1.0.
    margin = np.ceil(margin_ratio)

    # Filled background so the black text stays readable on any image.
    draw.rectangle([(left, string_bottom - string_height - 2 * margin),
                    (left + string_width, string_bottom)],
                   fill=color)
    draw.text((left + margin, string_bottom - string_height - margin),
              class_name, fill="black", font=font)
def draw_boxes(image, scores, class_names, boxes,
               max_boxes=10, min_score=0.1):
    """Overlay labelled bounding boxes on an image array.

    Only the first ``max_boxes`` entries are considered, and of those only
    entries whose score is at least ``min_score`` are drawn.

    Args:
        image: Image as a numpy array; it is modified in place.
        scores: Per-detection scores, indexed in step with ``boxes``.
        class_names: Per-detection class names as bytes (decoded as ASCII).
        boxes: Array whose rows unpack as ``(ymin, xmin, ymax, xmax)``.
        max_boxes: Maximum number of leading entries to look at.
        min_score: Minimum score an entry needs in order to be drawn.

    Returns:
        The same ``image`` array, with the boxes drawn into it.
    """
    colors = list(ImageColor.colormap.values())
    font = ImageFont.load_default()
    thickness = 4
    total_num_boxes = boxes.shape[0]

    # The PIL copy is created lazily and shared by all boxes, so the
    # array -> PIL -> array round trip happens once per call rather than once
    # per drawn box.
    pil_image = None
    for i in range(min(total_num_boxes, max_boxes)):
        if scores[i] >= min_score:
            if pil_image is None:
                pil_image = Image.fromarray(image).convert('RGB')
            ymin, xmin, ymax, xmax = tuple(boxes[i])
            display_str = "{}: {}%".format(class_names[i].decode("ascii"),
                                           int(100 * scores[i]))
            # NOTE: hash() of bytes is salted per interpreter process, so a
            # class keeps its colour within a run but not across runs.
            color = colors[hash(class_names[i]) % len(colors)]
            draw_bb_on_image(pil_image, xmin, ymin, xmax, ymax, color, font,
                             thickness, display_str_list=[display_str])

    # Write the drawing back only if at least one box was drawn.
    if pil_image is not None:
        np.copyto(image, np.array(pil_image))
    return image
def load_image(image_path):
    """Read the file at `image_path` and decode it as a 3-channel JPEG tensor."""
    encoded_bytes = tf.io.read_file(image_path)
    return tf.image.decode_jpeg(encoded_bytes, channels=3)
def run_detector(detector, image_path):
    """Run `detector` on the image at `image_path` and display the detections.

    Prints how long the detector call took and how many score entries it
    returned, then draws the boxes on the image and shows it.

    Args:
        detector: Callable that takes a batched float32 image tensor and
            returns a mapping of tensors with the keys 'detection_boxes',
            'detection_class_entities' and 'detection_scores'.
        image_path: Path of the JPEG file to analyse.
    """
    image = load_image(image_path)
    # Convert to float32 and add a leading batch axis.
    as_float = tf.image.convert_image_dtype(image, tf.float32)
    batched_input = as_float[tf.newaxis, ...]

    # Time only the detector call itself.
    started_at = time.time()
    raw_result = detector(batched_input)
    finished_at = time.time()

    # Turn every output tensor into a numpy array.
    result = {}
    for key, value in raw_result.items():
        result[key] = value.numpy()

    time_taken = finished_at - started_at
    # Counts every entry in 'detection_scores'; no score threshold is applied
    # to this number.
    num_objects_detected = len(result['detection_scores'])
    print("Time taken: ", time_taken)
    print("Number of objects detected: ", num_objects_detected)

    obj_detected_image = draw_boxes(
        image.numpy(),
        result['detection_scores'],
        result['detection_class_entities'],
        result['detection_boxes'],
    )
    display_image(obj_detected_image)
# Run the detector on the downloaded image and show the annotated result.
run_detector(detector, image_path)
# Notebook output:
# Time taken: 0.14460158348083496 Number of objects detected: 100