Advertisement
biplovbhandari

Untitled

Apr 23rd, 2024
802
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.42 KB | None | 0 0
  1. # Get relevant info from the JSON mixer file.
  2. affine_transform = mixer["projection"]["affine"]["doubleMatrix"]
  3. patch_dims = mixer["patchDimensions"]
  4. patches_per_row = mixer["patchesPerRow"]
  5. total_patches = mixer["totalPatches"]
  6.  
  7. # Path to your TFRecord file
  8. tfrecord_file = '/content/drive/MyDrive/Colab Notebooks/DL_Book/Chapter_1/output/unet_v1/prediction/prediction_unet_v1.TFRecord'
  9.  
  10. # Define the feature description for deserialization
  11. feature_description = {
  12.     # Create a dictionary describing the features.
  13.     'prediction': tf.io.FixedLenFeature([], tf.int64),
  14.     'cropland_etc': tf.io.FixedLenFeature([], tf.float32),
  15.     'rice': tf.io.FixedLenFeature([], tf.float32),
  16.     'forest': tf.io.FixedLenFeature([], tf.float32),
  17.     'urban': tf.io.FixedLenFeature([], tf.float32),
  18.     'others_etc': tf.io.FixedLenFeature([], tf.float32),
  19. }
  20.  
  21. def _parse_function(proto):
  22.     return tf.io.parse_single_example(proto, feature_description)
  23.  
  24. # Create a dataset from the TFRecord file
  25. raw_dataset = tf.data.TFRecordDataset(tfrecord_file)
  26. parsed_dataset = raw_dataset.map(_parse_function)
  27.  
  28. from osgeo import gdal, osr
  29. import cv2
  30.  
  31. # Initialize an empty array for the entire image
  32. full_image = np.zeros((patch_dims[0] * (total_patches // patches_per_row),
  33.                        patch_dims[1] * patches_per_row, 3), dtype=np.uint8)
  34.  
  35. # Iterate over each image in the parsed dataset
  36. for i, features in enumerate(parsed_dataset):
  37.     img = tf.image.decode_image(features['prediction']).numpy()
  38.     row = i // patches_per_row
  39.     col = i % patches_per_row
  40.     full_image[row * patch_dims[0]:(row + 1) * patch_dims[0],
  41.                col * patch_dims[1]:(col + 1) * patch_dims[1]] = img
  42.  
  43. # Create a GeoTIFF file
  44. driver = gdal.GetDriverByName('GTiff')
  45. outRaster = driver.Create('output.tif', full_image.shape[1], full_image.shape[0], 3, gdal.GDT_Byte)
  46. outRaster.SetGeoTransform([affine_transform[2], affine_transform[0], 0,
  47.                            affine_transform[5], 0, affine_transform[4]])
  48.  
  49. # Set the projection
  50. outRasterSRS = osr.SpatialReference()
  51. outRasterSRS.ImportFromEPSG(4326)
  52. outRaster.SetProjection(outRasterSRS.ExportToWkt())
  53.  
  54. # Write the data
  55. outband = outRaster.GetRasterBand(1)
  56. outband.WriteArray(full_image[:,:,0])
  57. outband = outRaster.GetRasterBand(2)
  58. outband.WriteArray(full_image[:,:,1])
  59. outband = outRaster.GetRasterBand(3)
  60. outband.WriteArray(full_image[:,:,2])
  61.  
  62. # Flush data
  63. outRaster.FlushCache()
  64.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement