# File: gpr-sidl-inv/3_combine_dataset.py
# Commit: 699f32f283 ("program")
# Signed-off-by: Ge Junkai (葛峻恺) <202115006@mail.sdu.edu.cn>
# Date: 2025-04-07 12:17:39 +00:00
# 93 lines, 3.8 KiB, Python

import pandas as pd
import numpy as np
import os
import scipy.ndimage
from config import Forward_Model_Config as cfg
from config import Path_Config as pcfg
# Parameters
data_per_ns = cfg.data_per_ns # data grid num per ns
DATA_LENGTH = cfg.data_length # Target length after interpolation (unit: grid number)
THRESHOLD = cfg.filter_threthold # Minimum absolute value threshold for valid data
# Input and output paths
INPUT_DATA_FOLDER = pcfg.INPUT_DATA_FOLDER
INPUT_LABEL_FOLDER = pcfg.INPUT_LABEL_FOLDER
DATA_OUTPUT_FILE = pcfg.dataset_path
LABEL_OUTPUT_FILE = pcfg.labelset_path
# Get all data and label filenames
data_files = sorted([f for f in os.listdir(INPUT_DATA_FOLDER) if f.endswith('.txt')])
label_files = sorted([f for f in os.listdir(INPUT_LABEL_FOLDER) if f.endswith('.csv')])
# Extract valid file IDs
data_ids = set(os.path.splitext(f)[0] for f in data_files)
label_ids = set(os.path.splitext(f)[0] for f in label_files)
valid_ids = sorted(data_ids & label_ids) # Only process matching data-label pairs
# Storage for processed data and labels
all_data = []
all_labels = []
for data_id in valid_ids:
try:
# Define file paths
data_path = os.path.join(INPUT_DATA_FOLDER, data_id + ".txt")
label_path = os.path.join(INPUT_LABEL_FOLDER, data_id + ".csv")
# Load data file (1D signal data)
raw_data = np.loadtxt(data_path, delimiter=",", skiprows=0)
processed_data = raw_data[:] # Remove direct wave component
# Interpolate data to the target length
processed_data = scipy.ndimage.zoom(processed_data, DATA_LENGTH / processed_data.shape[0], order=1)
# Skip data if the maximum absolute value is below the threshold
if np.max(np.abs(processed_data)) < THRESHOLD:
print(f"Skipping {data_id} due to low signal amplitude.")
continue
# Append processed data
all_data.append(processed_data)
# Load label file and remove the first column
label = np.delete(np.loadtxt(label_path, delimiter=",", skiprows=0), [0], axis=0)
all_labels.append(label)
except Exception as e:
print(f"Error processing file {data_id}: {e}")
continue
# Convert lists to 2D arrays
combined_data = np.column_stack(all_data) if all_data else np.array([])
combined_labels = np.column_stack(all_labels) if all_labels else np.array([])
# Save processed data
if combined_data.size > 0:
data_header = np.array([int(id_) for id_ in valid_ids]).reshape(1, -1)
if data_header.shape[1] != combined_data.shape[1]:
print(f"Warning: Mismatch between data header ({data_header.shape[1]}) and data columns ({combined_data.shape[1]}). Trimming header.")
data_header = data_header[:, :combined_data.shape[1]]
combined_data_with_header = np.vstack((data_header, combined_data))
np.savetxt(DATA_OUTPUT_FILE, combined_data_with_header, delimiter=",", fmt="%g")
print(f"Processed data saved to {DATA_OUTPUT_FILE}")
else:
print("No valid data files found.")
# Save processed labels
if combined_labels.size > 0:
label_header = np.array([int(id_) for id_ in valid_ids]).reshape(1, -1)
if label_header.shape[1] != combined_labels.shape[1]:
print(f"Warning: Mismatch between label header ({label_header.shape[1]}) and label columns ({combined_labels.shape[1]}). Trimming header.")
label_header = label_header[:, :combined_labels.shape[1]]
combined_labels_with_header = np.vstack((label_header, combined_labels))
np.savetxt(LABEL_OUTPUT_FILE, combined_labels_with_header, delimiter=",", fmt="%g")
print(f"Processed labels saved to {LABEL_OUTPUT_FILE}")
else:
print("No valid label files found.")