#!/usr/bin/env python
"""Train and evaluate a small dense classifier on Fashion-MNIST.

Loads the dataset through ``keras.datasets``, shows sample images, trains a
Flatten -> Dense(128, relu) -> Dense(10, softmax) network for five epochs,
prints the test accuracy, and plots the predictions for the first 25 test
images (green label = correct prediction, red = incorrect).
"""
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt

print(tf.__version__)

# Load the dataset via keras.datasets.
fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

# Human-readable names, indexed by the integer class label.
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

# Show the first training image BEFORE scaling, with a colorbar, so the raw
# pixel-value range is visible.
plt.figure()
plt.imshow(train_images[0])
plt.colorbar()
plt.gca().grid(False)
plt.title('The first image in the training set')

# Preprocess the data: divide by 255 so 8-bit intensities land in [0, 1].
train_images = train_images / 255.0
test_images = test_images / 255.0

# Display the first 25 images from the training set.
plt.figure(figsize=(10, 10))
for i in range(25):
    plt.subplot(5, 5, i + 1)
    plt.xticks([])
    plt.yticks([])
    # Use a real boolean; the string 'off' is truthy and does not disable
    # the grid reliably.
    plt.grid(False)
    plt.imshow(train_images[i], cmap=plt.cm.binary)
    plt.xlabel(class_names[train_labels[i]])
plt.suptitle('The first 25 images from the training set')

# Set up the layers: flatten each 28x28 image into a vector, one hidden
# ReLU layer, then a 10-way softmax over the classes.
model = keras.Sequential([
    keras.layers.Flatten(input_shape=(28, 28)),
    keras.layers.Dense(128, activation=tf.nn.relu),
    keras.layers.Dense(10, activation=tf.nn.softmax),
])

# Compile the model. `tf.train.AdamOptimizer` is a TF1-only API that is no
# longer available at that path in TF 2.x; the Keras optimizer works in both.
# The sparse loss matches the integer (not one-hot) labels from load_data().
model.compile(optimizer=keras.optimizers.Adam(),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Train the model.
model.fit(train_images, train_labels, epochs=5)

# Evaluate accuracy on the held-out test set.
test_loss, test_acc = model.evaluate(test_images, test_labels)
print('Test accuracy: ', test_acc)

# Make predictions: one row of class probabilities per test image.
predictions = model.predict(test_images)

# Plot the first 25 test images labelled "predicted (true)".
# Correct predictions are shown in green, incorrect ones in red.
plt.figure(figsize=(10, 10))
for i in range(25):
    plt.subplot(5, 5, i + 1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(test_images[i], cmap=plt.cm.binary)
    predicted_label = np.argmax(predictions[i])
    true_label = test_labels[i]
    color = 'green' if predicted_label == true_label else 'red'
    plt.xlabel("{} ({})".format(
        class_names[predicted_label],
        class_names[true_label]),
        color=color)
plt.suptitle('The first 25 images from the test set')

# Show figures.
plt.show()