你已经派生过 gpr-sidl-inv
镜像自地址
https://gitee.com/sduem/gpr-sidl-inv.git
已同步 2025-08-03 18:56:51 +08:00
92
3_combine_dataset.py
普通文件
92
3_combine_dataset.py
普通文件
@@ -0,0 +1,92 @@
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import os
|
||||
import scipy.ndimage
|
||||
from config import Forward_Model_Config as cfg
|
||||
from config import Path_Config as pcfg
|
||||
# ---- Signal-processing settings (read from Forward_Model_Config) ----
data_per_ns = cfg.data_per_ns      # number of data grid points per ns
DATA_LENGTH = cfg.data_length      # grid points per trace after interpolation
THRESHOLD = cfg.filter_threthold   # minimum peak |amplitude| for a trace to count as valid

# ---- File-system locations (read from Path_Config) ----
# Input folders first, then the two combined output files.
INPUT_DATA_FOLDER, INPUT_LABEL_FOLDER = pcfg.INPUT_DATA_FOLDER, pcfg.INPUT_LABEL_FOLDER
DATA_OUTPUT_FILE, LABEL_OUTPUT_FILE = pcfg.dataset_path, pcfg.labelset_path
|
||||
|
||||
# Candidate file names in each input folder, in sorted order
data_files = sorted(name for name in os.listdir(INPUT_DATA_FOLDER) if name.endswith('.txt'))
label_files = sorted(name for name in os.listdir(INPUT_LABEL_FOLDER) if name.endswith('.csv'))

# A sample ID is the file name with its extension stripped
data_ids = {os.path.splitext(name)[0] for name in data_files}
label_ids = {os.path.splitext(name)[0] for name in label_files}

# Only IDs that have both a data file and a label file are processed
valid_ids = sorted(data_ids.intersection(label_ids))
|
||||
|
||||
def _save_with_id_header(output_file, columns, sample_ids, kind):
    """Write *columns* to *output_file* as CSV with a first row of sample IDs.

    Args:
        output_file: Destination path passed to ``np.savetxt``.
        columns: 2D array holding one column per sample.
        sample_ids: Integer IDs, one per column, in column order.
        kind: Short noun ("data" / "label") used in the error message.

    Raises:
        ValueError: If the number of IDs differs from the number of columns,
            because writing would then attach IDs to the wrong columns.
    """
    header = np.array(sample_ids).reshape(1, -1)
    if header.shape[1] != columns.shape[1]:
        raise ValueError(
            f"Mismatch between {kind} header ({header.shape[1]}) "
            f"and {kind} columns ({columns.shape[1]})."
        )
    # NOTE: "%g" writes 6 significant digits, so values (and IDs >= 1e6)
    # are rounded in the output file.
    np.savetxt(output_file, np.vstack((header, columns)), delimiter=",", fmt="%g")


# Storage for processed data and labels. kept_ids records, in order, the ID of
# every sample that was actually stored, so header rows match the columns even
# when some samples are skipped.
all_data = []
all_labels = []
kept_ids = []

for data_id in valid_ids:
    data_path = os.path.join(INPUT_DATA_FOLDER, data_id + ".txt")
    label_path = os.path.join(INPUT_LABEL_FOLDER, data_id + ".csv")

    try:
        # The header row is numeric, so a non-numeric file name is rejected
        # here (per file) instead of crashing the whole run at save time.
        numeric_id = int(data_id)

        # Load the signal; the full trace is used as-is (nothing is cut off).
        raw_data = np.loadtxt(data_path, delimiter=",", skiprows=0)

        # Linearly interpolate the trace to the target length
        processed_data = scipy.ndimage.zoom(
            raw_data, DATA_LENGTH / raw_data.shape[0], order=1
        )

        # Skip data if the maximum absolute value is below the threshold
        if np.max(np.abs(processed_data)) < THRESHOLD:
            print(f"Skipping {data_id} due to low signal amplitude.")
            continue

        # Load the label file and drop its first entry along axis 0
        label = np.delete(
            np.loadtxt(label_path, delimiter=",", skiprows=0), [0], axis=0
        )
    except Exception as e:
        print(f"Error processing file {data_id}: {e}")
        continue

    # Store data, label and ID together, only after both files loaded
    # successfully, so the three lists can never get out of step.
    all_data.append(processed_data)
    all_labels.append(label)
    kept_ids.append(numeric_id)

# Convert lists to 2D arrays (one column per sample)
combined_data = np.column_stack(all_data) if all_data else np.array([])
combined_labels = np.column_stack(all_labels) if all_labels else np.array([])

# Save processed data
if combined_data.size > 0:
    _save_with_id_header(DATA_OUTPUT_FILE, combined_data, kept_ids, "data")
    print(f"Processed data saved to {DATA_OUTPUT_FILE}")
else:
    print("No valid data files found.")

# Save processed labels
if combined_labels.size > 0:
    _save_with_id_header(LABEL_OUTPUT_FILE, combined_labels, kept_ids, "label")
    print(f"Processed labels saved to {LABEL_OUTPUT_FILE}")
else:
    print("No valid label files found.")
|
||||
|
在新工单中引用
屏蔽一个用户