I have run into some problems trying to train a network to fit a function of multivariate quadratic terms, namely the Euclidean distance between 2 points in 3-dimensional space that are 'pretty far from each other' (typically between $2.1\times10^7$ and $2.4\times10^7$ apart).
To begin with, I have a large set of simulated data containing the coordinates of both points and the distance between them, calculated with exactly the formula I am trying to fit:
$d=\sqrt{\Delta x^2+\Delta y^2+\Delta z^2}$.
The input is a set of 3 decimal numbers: the differences between the coordinates of the two points, one of which is very close to the origin and the other far away. (Originally there were 6 inputs, before I did the subtraction myself.) The set is generated by a simulation with numpy.random(size = 3) and then scaled up with a suitable multiplier.
With Keras I constructed something like this:
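A minimal sketch of this kind of setup, for anyone who wants to reproduce the shapes involved. It is not the actual simulation code: the seed, `n_samples`, the `scale` multiplier, and the use of `numpy.random.default_rng` are all assumptions made for illustration, and the resulting distances do not reproduce the $2.1\times10^7$ to $2.4\times10^7$ range mentioned above.

```python
import numpy as np

# Hypothetical values, chosen only for illustration.
rng = np.random.default_rng(0)
n_samples = 1000
scale = 1.3e7  # assumed multiplier, not the one used in the simulation

# Coordinate differences (dx, dy, dz), one row per sample.
X_dist = rng.random(size=(n_samples, 3)) * scale

# Target: d = sqrt(dx^2 + dy^2 + dz^2), the same formula as above.
Y_dist = np.sqrt(np.sum(X_dist ** 2, axis=1))

print(X_dist.shape, Y_dist.shape)
```

The target here is identical to `np.linalg.norm(X_dist, axis=1)`; the explicit form is written out only to mirror the formula.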
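from sklearn.model_selection import train_test_split
from keras.models import Sequential
from keras.layers import InputLayer, Dense, LeakyReLU
from keras.optimizers import Adam
from keras.callbacks import EarlyStopping, ReduceLROnPlateau

train_f, test_f, train_t, test_t = train_test_split(X_dist, Y_dist, test_size = 0.3)  # 70/30 split
nn_r = Sequential()
nn_r.add(InputLayer(input_shape = (train_f.shape[1], )))
nn_r.add(Dense(6, activation = 'linear'))
nn_r.add(Dense(12, activation = LeakyReLU()))
nn_r.add(Dense(12, activation = LeakyReLU()))
nn_r.add(Dense(1, activation = 'linear'))  # single regression output: the distance
nn_r.compile(loss = 'mse', optimizer = Adam(), metrics = ['mae', 'mse'])
es_r = EarlyStopping(monitor = 'loss', patience = 30)  # stop after 30 epochs without improvement in training loss
lr_r = ReduceLROnPlateau(monitor = 'loss')  # lower the learning rate when training loss plateaus
nn_r.fit(train_f, train_t, epochs = ep, verbose = 1, callbacks = [es_r, lr_r])  # ep: number of epochs, set elsewhere
score_r = nn_r.evaluate(x = test_f, y = test_t)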
The network accepts the design and starts off well, but never gives a reasonable result; the MSE ends up somewhere above $1\times10^9$.
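For scale: the loss is a mean squared error, so its square root is in the same units as the distance. A quick check, taking $1\times10^9$ as the MSE and the two ends of the typical distance range quoted above (the variable names are just for this illustration):

```python
import math

mse = 1e9              # approximate final MSE reported above
rmse = math.sqrt(mse)  # back in the same units as the distance

print(f"RMSE = {rmse:.1f}")
for d in (2.1e7, 2.4e7):  # ends of the typical distance range
    print(f"d = {d:.1e}: RMSE / d = {rmse / d:.4%}")
```

So an MSE of $1\times10^9$ corresponds to an RMSE of roughly $3.2\times10^4$, which is between about 0.13% and 0.15% of the typical distance.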
I have tried normalizing the input with sklearn.preprocessing, which did accelerate training but did not help with the accuracy. Modifying the network structure affected the result significantly, but never ended up with an acceptable result either. I even tried doing the subtraction myself and feeding the network the differences, and that did not work either.
Additionally, the system might need to deal with noisy data, so I decided to give NNs a try instead of simply plugging the numbers back into the formula above.
I'm not sure where the problem comes from or how I can make it work better. Thank you in advance.
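As an illustration of what input normalization with `sklearn.preprocessing` can look like, here is a minimal sketch. `StandardScaler` is an assumption (other scalers in that module are used the same way), and the arrays are synthetic stand-ins for `train_f` and `test_f`, with a made-up multiplier.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Synthetic stand-ins for the real training and test features.
rng = np.random.default_rng(0)
train_f = rng.random(size=(700, 3)) * 1.3e7
test_f = rng.random(size=(300, 3)) * 1.3e7

# Fit the scaler on the training inputs only, then reuse it for the test inputs.
scaler = StandardScaler()
train_f_scaled = scaler.fit_transform(train_f)
test_f_scaled = scaler.transform(test_f)

print(train_f_scaled.mean(axis=0), train_f_scaled.std(axis=0))
```

In this sketch only the inputs are scaled; the targets stay in their original units, so an MSE computed against them is still in squared distance units.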