Mirrored from https://gitee.com/sduem/gpr-sidl-inv.git (synced 2025-08-02 18:36:51 +08:00)
93 lines · 3.8 KiB · Python

import pandas as pd
import numpy as np
import os
import scipy.ndimage

from config import Forward_Model_Config as cfg
from config import Path_Config as pcfg

# Parameters
data_per_ns = cfg.data_per_ns  # data grid points per nanosecond
DATA_LENGTH = cfg.data_length  # target length after interpolation (unit: grid points)
THRESHOLD = cfg.filter_threthold  # minimum absolute-value threshold for valid data
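
# Note: data_per_ns is read from the config but not used below; presumably
# other stages of the pipeline (e.g. the forward model) consume it.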

# Input and output paths
INPUT_DATA_FOLDER = pcfg.INPUT_DATA_FOLDER
INPUT_LABEL_FOLDER = pcfg.INPUT_LABEL_FOLDER
DATA_OUTPUT_FILE = pcfg.dataset_path
LABEL_OUTPUT_FILE = pcfg.labelset_path

# Get all data and label filenames
data_files = sorted(f for f in os.listdir(INPUT_DATA_FOLDER) if f.endswith('.txt'))
label_files = sorted(f for f in os.listdir(INPUT_LABEL_FOLDER) if f.endswith('.csv'))

# Extract valid file IDs
data_ids = set(os.path.splitext(f)[0] for f in data_files)
label_ids = set(os.path.splitext(f)[0] for f in label_files)
valid_ids = sorted(data_ids & label_ids)  # only process matching data-label pairs
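
# Example: data IDs {"0001", "0002"} and label IDs {"0002", "0003"} intersect
# to valid_ids == ["0002"]; files without a matching counterpart are ignored.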

# Storage for processed data, labels, and the IDs that survive filtering
all_data = []
all_labels = []
kept_ids = []  # added so the saved headers list only the IDs actually kept

for data_id in valid_ids:
    try:
        # Define file paths
        data_path = os.path.join(INPUT_DATA_FOLDER, data_id + ".txt")
        label_path = os.path.join(INPUT_LABEL_FOLDER, data_id + ".csv")

        # Load data file (1D signal trace)
        raw_data = np.loadtxt(data_path, delimiter=",", skiprows=0)
        # Full copy; the direct-wave removal this step was meant to do is a
        # no-op here (a slice such as raw_data[n0:] would actually drop the
        # leading samples)
        processed_data = raw_data[:]

        # Interpolate the trace to the target length
        processed_data = scipy.ndimage.zoom(processed_data, DATA_LENGTH / processed_data.shape[0], order=1)
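
        # scipy.ndimage.zoom with order=1 performs piecewise-linear
        # resampling: an input of length N is scaled by DATA_LENGTH / N, so
        # the output has (up to rounding) exactly DATA_LENGTH samples, e.g. a
        # 4-sample trace zoomed by 2.0 yields 8 linearly interpolated samples.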

        # Skip the trace if its maximum absolute value is below the threshold
        if np.max(np.abs(processed_data)) < THRESHOLD:
            print(f"Skipping {data_id} due to low signal amplitude.")
            continue
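        # Rationale (an assumption, not stated in the source): traces whose
        # peak |amplitude| never reaches THRESHOLD are treated as empty or
        # failed forward simulations and excluded from the dataset.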

        # Load the label file and drop its first entry along axis 0
        label = np.delete(np.loadtxt(label_path, delimiter=",", skiprows=0), [0], axis=0)

        # Append data, label, and ID together only after both loads succeed,
        # so the three collections always stay column-aligned
        all_data.append(processed_data)
        all_labels.append(label)
        kept_ids.append(data_id)
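        # Note: np.delete(a, [0], axis=0) removes the first row of a 2-D array
        # but the first element of a 1-D one; np.loadtxt returns 1-D for a
        # single-column CSV and 2-D otherwise.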

    except Exception as e:
        print(f"Error processing file {data_id}: {e}")
        continue

# Convert lists to 2D arrays (one column per file)
combined_data = np.column_stack(all_data) if all_data else np.array([])
combined_labels = np.column_stack(all_labels) if all_labels else np.array([])
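
# Resulting shapes (assuming each label loads as a 1-D vector): combined_data
# is (DATA_LENGTH, n_kept) and combined_labels is (label_length, n_kept);
# column j of each belongs to kept_ids[j].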

# Save processed data
if combined_data.size > 0:
    # Header built from kept_ids (not valid_ids) so each column's ID matches
    # the trace that was actually kept
    data_header = np.array([int(id_) for id_ in kept_ids]).reshape(1, -1)

    # With kept_ids this mismatch should no longer occur; kept as a safety net
    if data_header.shape[1] != combined_data.shape[1]:
        print(f"Warning: Mismatch between data header ({data_header.shape[1]}) and data columns ({combined_data.shape[1]}). Trimming header.")
        data_header = data_header[:, :combined_data.shape[1]]

    combined_data_with_header = np.vstack((data_header, combined_data))
    np.savetxt(DATA_OUTPUT_FILE, combined_data_with_header, delimiter=",", fmt="%g")
    print(f"Processed data saved to {DATA_OUTPUT_FILE}")
else:
    print("No valid data files found.")
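
# Output layout: row 0 of each CSV holds the integer file IDs; the remaining
# rows hold, column by column, the resampled trace (or label vector) for the
# ID in that column.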

# Save processed labels
if combined_labels.size > 0:
    label_header = np.array([int(id_) for id_ in kept_ids]).reshape(1, -1)

    # Same safety net as for the data header
    if label_header.shape[1] != combined_labels.shape[1]:
        print(f"Warning: Mismatch between label header ({label_header.shape[1]}) and label columns ({combined_labels.shape[1]}). Trimming header.")
        label_header = label_header[:, :combined_labels.shape[1]]

    combined_labels_with_header = np.vstack((label_header, combined_labels))
    np.savetxt(LABEL_OUTPUT_FILE, combined_labels_with_header, delimiter=",", fmt="%g")
    print(f"Processed labels saved to {LABEL_OUTPUT_FILE}")
else:
    print("No valid label files found.")
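
# Optional sanity check (a minimal sketch, not part of the original script):
# read the saved dataset back and confirm the layout described above. Assumes
# DATA_OUTPUT_FILE was written in this run.
if combined_data.size > 0:
    check = np.loadtxt(DATA_OUTPUT_FILE, delimiter=",", ndmin=2)
    ids, traces = check[0].astype(int), check[1:]
    print(f"Read back {traces.shape[1]} traces of length {traces.shape[0]}; first IDs: {ids[:5]}")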