Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
# Read the export layout from the Earth Engine mixer JSON (`mixer` is loaded
# elsewhere): the projection's affine coefficients plus the patch-grid
# geometry used to stitch prediction patches back together.
affine_transform = mixer["projection"]["affine"]["doubleMatrix"]
patch_dims, patches_per_row, total_patches = (
    mixer["patchDimensions"],
    mixer["patchesPerRow"],
    mixer["totalPatches"],
)
# Location of the model-prediction TFRecord produced for the unet_v1 run.
tfrecord_file = (
    '/content/drive/MyDrive/Colab Notebooks/DL_Book/Chapter_1/output/'
    'unet_v1/prediction/prediction_unet_v1.TFRecord'
)
from osgeo import gdal, osr
import cv2  # imported by the original cell; not used in this section

# Schema of each serialized tf.train.Example in the prediction TFRecord.
# Patch records store every band as a flat list of
# patch_dims[0] * patch_dims[1] values, so each feature is declared with the
# patch shape; a scalar shape `[]` makes parse_single_example fail on such
# multi-valued features.
# NOTE(review): assumes the predictions were written as unbuffered patches of
# exactly `patch_dims` pixels (no kernel overlap) — confirm against the code
# that produced the TFRecord.
feature_description = {
    # int64 per-pixel label; presumably the predicted class index.
    'prediction': tf.io.FixedLenFeature(patch_dims, tf.int64),
    # float32 per-pixel outputs, one per class (presumably probabilities).
    'cropland_etc': tf.io.FixedLenFeature(patch_dims, tf.float32),
    'rice': tf.io.FixedLenFeature(patch_dims, tf.float32),
    'forest': tf.io.FixedLenFeature(patch_dims, tf.float32),
    'urban': tf.io.FixedLenFeature(patch_dims, tf.float32),
    'others_etc': tf.io.FixedLenFeature(patch_dims, tf.float32),
}


def _parse_function(proto):
    """Decode one serialized Example into a dict of tensors shaped `patch_dims`.

    Args:
        proto: scalar string tensor holding a serialized tf.train.Example.

    Returns:
        Dict mapping each key of `feature_description` to its parsed tensor.
    """
    return tf.io.parse_single_example(proto, feature_description)


# Stream and decode the prediction records.
raw_dataset = tf.data.TFRecordDataset(tfrecord_file)
parsed_dataset = raw_dataset.map(_parse_function)

# Mosaic the patches into one single-band label image. Patches are placed in
# record order, row-major, `patches_per_row` per row. Ceil division keeps a
# partially filled last row instead of silently dropping it.
patch_h, patch_w = patch_dims
n_patch_rows = (total_patches + patches_per_row - 1) // patches_per_row
full_image = np.zeros((patch_h * n_patch_rows, patch_w * patches_per_row),
                      dtype=np.uint8)

for i, features in enumerate(parsed_dataset):
    if i >= total_patches:
        raise ValueError(
            f"TFRecord contains more than total_patches={total_patches} patches")
    # The label feature is numeric (int64), not an encoded image, so it is
    # used directly rather than passed through tf.image.decode_image.
    patch = features['prediction'].numpy()
    # Labels are stored as uint8 (GDT_Byte below); refuse to wrap silently.
    if patch.min() < 0 or patch.max() > 255:
        raise ValueError(
            f"patch {i}: prediction values outside 0..255 cannot be stored as uint8")
    row, col = divmod(i, patches_per_row)
    full_image[row * patch_h:(row + 1) * patch_h,
               col * patch_w:(col + 1) * patch_w] = patch.astype(np.uint8)

# Write the mosaic as a GeoTIFF with one Byte band per channel of full_image
# (a 2-D label map becomes a single band).
bands = full_image if full_image.ndim == 3 else full_image[:, :, np.newaxis]
driver = gdal.GetDriverByName('GTiff')
outRaster = driver.Create('output.tif', bands.shape[1], bands.shape[0],
                          bands.shape[2], gdal.GDT_Byte)

# Earth Engine's affine doubleMatrix is
#   [xScale, xShearing, xTranslation, yShearing, yScale, yTranslation],
# i.e. x = m0*col + m1*row + m2 and y = m3*col + m4*row + m5. GDAL's
# geotransform is (originX, pixelWidth, rowRotation, originY, colRotation,
# pixelHeight), so the shear terms belong in GT[2] and GT[4] rather than 0.
outRaster.SetGeoTransform([affine_transform[2], affine_transform[0],
                           affine_transform[1], affine_transform[5],
                           affine_transform[3], affine_transform[4]])

# Use the CRS recorded in the mixer; fall back to EPSG:4326 when the mixer
# carries no `crs` entry.
outRasterSRS = osr.SpatialReference()
if outRasterSRS.SetFromUserInput(mixer["projection"].get("crs", "EPSG:4326")) != 0:
    raise ValueError(
        f"Unrecognized CRS in mixer: {mixer['projection'].get('crs')!r}")
outRaster.SetProjection(outRasterSRS.ExportToWkt())

# Write the data; GDAL band indices are 1-based.
for band_index in range(bands.shape[2]):
    outband = outRaster.GetRasterBand(band_index + 1)
    outband.WriteArray(bands[:, :, band_index])

# Push buffered data to disk. GDAL finalizes the file when the dataset is
# closed (e.g. `outRaster = None`), so close it before other tools read
# output.tif.
outRaster.FlushCache()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement