Re: [Dclib-devel] [dclib-devel] Implementing a GAN with dlib
From: Eloi Du B. <elo...@gm...> - 2017-10-22 05:40:20
Ok,
I'm starting to dig a little deeper to understand how dlib is put together,
and I think I have an idea of how to implement this. I have a few questions
about forward vs. backpropagation.
I'm defining a discriminator as:

    // 1 filter, 1x5 conv layer, stride 1x2 (2x downsampling along the row)
    template <typename SUBNET>
    using con2 = dlib::con<1, 1, 5, 1, 2, SUBNET>;

    template <typename SUBNET>
    using DiscriminatorT =
        dlib::loss_binary_log<dlib::fc<1,
        dlib::relu<dlib::max_pool<1, 2, 1, 2,
        dlib::relu<dlib::max_pool<1, 2, 1, 2,
        dlib::relu<con2<SUBNET>>>>>>>>;
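As a sanity check on the shapes (assuming a 1x128 input row and zero padding — dlib only pads the conv, I believe, when the stride is 1, so treat the exact numbers as an assumption), the widths can be worked out with the usual formula:

```cpp
// Output width of a 1-D conv or pool window:
//   floor((in - window) / stride) + 1
// assuming zero padding on both sides.
int out_width(int in, int window, int stride)
{
    return (in - window) / stride + 1;
}
```

That gives 128 -> conv(5, stride 2) -> 62 -> pool(2, stride 2) -> 31 -> pool(2, stride 2) -> 15, so the fc<1> layer sees 15 values.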
This is just to detect whether a vector of 128 floats follows a Gaussian
distribution or not.
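For reference, the positive/negative examples I have in mind can be generated with the standard library alone (the helper name here is my own, purely illustrative — it's not a dlib API):

```cpp
#include <random>
#include <vector>

// Draw one 128-float training vector: Gaussian samples for positive
// examples, uniform noise for negative ones.
std::vector<float> make_example(bool gaussian, std::mt19937& rng)
{
    std::vector<float> v(128);
    std::normal_distribution<float> normal(0.0f, 1.0f);
    std::uniform_real_distribution<float> uniform(-1.0f, 1.0f);
    for (auto& x : v)
        x = gaussian ? normal(rng) : uniform(rng);
    return v;
}
```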
Then my discriminator is trained on some test data, with positive and
negative examples (Gaussian vs. noise).
Now I'm checking that a forward() call on the subnet gives the same result
as calling the network directly:

    std::vector<float> classificationForTest = discriminator(testVectors);

    dlib::resizable_tensor tempTensor;
    discriminator.to_tensor(&testVectors[0], &testVectors[0] + 1, tempTensor);
    auto& tensorOut = discriminator.subnet().forward(tempTensor);
    assert(tensorOut.host()[0] == classificationForTest[0]);
This test passes, so the forward pass works the way I thought it did.
Now I'm trying to back-propagate the error. Let's say
classificationForTest[0] is -8 when it should be 8, so my error is 16.
I'm tempted to do something like (pseudo code):

    discriminator.subnet().back_propagate_error(tensor = 16);
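One caveat (my own observation, not something I've verified against the dlib sources): with loss_binary_log the quantity fed backward should be the derivative of the log loss log(1 + exp(-y*f)) with respect to the output f, not the raw difference like 16. A self-contained check of that derivative:

```cpp
#include <cmath>

// Derivative of the binary log loss L(f) = log(1 + exp(-y*f)) with
// respect to the network output f, for a label y in {-1, +1}. This is
// the per-sample error signal a log-loss layer feeds into backprop.
double log_loss_gradient(double f, double y)
{
    const double s = 1.0 / (1.0 + std::exp(-y * f));  // sigmoid(y*f)
    return y * (s - 1.0);
}
```

For f = -8 and y = +1 this gives roughly -0.9997, nowhere near -16: the log loss saturates, so once the discriminator is confidently wrong the raw difference wildly overstates the usable gradient.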
My question is: how do I back-propagate through the full network and
recover the error contributed by each scalar of an input vector fed to the
discriminator?
If I get this working, then I know how to solve my GAN, as I would inject
this error into the output tensor of the generator.
Many thanks for any help,
2017-10-21 20:08 GMT-05:00 Davis King <dav...@gm...>:
> I would make a network that forks into two subnetworks and then write your
> own loss layer that does whatever you want to do. I doubt you could
> implement any of the GAN papers using a binary classification loss like
> loss_binary_log. They all have very specific loss functions that only make
> sense in the context of GAN.
>
> _______________________________________________
> Dclib-devel mailing list
> Dcl...@li...
> https://lists.sourceforge.net/lists/listinfo/dclib-devel