How to create a GRU in PyTorch

This recipe helps you create a GRU in PyTorch.

Recipe Objective

How to create a GRU in PyTorch?

This is achieved with torch.nn.GRU, a module that applies a multi-layer gated recurrent unit (GRU) RNN to an input sequence.
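At every time step a GRU combines a reset gate, an update gate and a candidate state, following the gate equations given in the PyTorch documentation. The sketch below recomputes one step by hand from a layer's own weights to show what torch.nn.GRU computes internally. The sizes (3 inputs, 5 hidden units), the seed and the helper name gru_step are illustrative choices, not part of the recipe; the code relies on PyTorch stacking each layer's gate weights in the order reset, update, new.

```python
import torch

torch.manual_seed(0)
gru = torch.nn.GRU(input_size=3, hidden_size=5, num_layers=1)
x = torch.randn(1, 1, 3)   # one time step: (seq_len=1, batch=1, input_size=3)
h0 = torch.randn(1, 1, 5)  # (num_layers=1, batch=1, hidden_size=5)


def gru_step(x_t, h_prev, w_ih, w_hh, b_ih, b_hh):
    """One GRU update (hypothetical helper), following the PyTorch gate equations."""
    # Input and hidden projections; rows are stacked as (reset, update, new)
    gi_r, gi_z, gi_n = (x_t @ w_ih.T + b_ih).chunk(3, dim=-1)
    gh_r, gh_z, gh_n = (h_prev @ w_hh.T + b_hh).chunk(3, dim=-1)
    r = torch.sigmoid(gi_r + gh_r)    # reset gate
    z = torch.sigmoid(gi_z + gh_z)    # update gate
    n = torch.tanh(gi_n + r * gh_n)   # candidate state; reset gate scales the hidden part
    return (1 - z) * n + z * h_prev   # blend candidate and previous state


with torch.no_grad():
    out, h1 = gru(x, h0)
    manual = gru_step(x[0], h0[0], gru.weight_ih_l0, gru.weight_hh_l0,
                      gru.bias_ih_l0, gru.bias_hh_l0)

print(torch.allclose(manual, out[0], atol=1e-5))  # True
```

Stacking several layers (num_layers > 1) simply feeds each layer's hidden states into the next layer as its input sequence.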


Step 1 - Import the library

import torch

Step 2 - Create the GRU

my_gru = torch.nn.GRU(20, 20, 4)  # input_size=20, hidden_size=20, num_layers=4
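To see what the three positional arguments mean, the sketch below builds the same layer with keyword arguments and lists its learnable parameters. The names such as weight_ih_l0 are PyTorch's own; the total is just the sum of the sizes of all weight and bias tensors.

```python
import torch

# Same layer as in Step 2, written with keyword arguments for readability
my_gru = torch.nn.GRU(input_size=20, hidden_size=20, num_layers=4)

shapes = {name: tuple(p.shape) for name, p in my_gru.named_parameters()}
for name, shape in shapes.items():
    print(name, shape)

# Each layer stacks 3 gates (reset, update, new), hence 3 * 20 = 60 rows per matrix
total = sum(p.numel() for p in my_gru.parameters())
print("total parameters:", total)  # 10080
```

Layers after the first receive the previous layer's 20-dimensional hidden state as input, so every layer here has the same parameter shapes.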

Step 3 - Define the input and the initial hidden state

input_data = torch.randn(10, 4, 20)  # (seq_len=10, batch=4, input_size=20)
h_0_data = torch.randn(4, 4, 20)     # (num_layers=4, batch=4, hidden_size=20)
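In this recipe the batch size and the number of layers are both 4, which makes it easy to mix up the dimensions of h_0_data. The sketch below uses a batch of 3 (an arbitrary choice) so the shapes are distinguishable, and checks two documented behaviours: leaving out h_0 starts from an all-zero hidden state, and batch_first=True changes the layout of the input and output but not of h_0.

```python
import torch

seq_len, batch, input_size, hidden_size, num_layers = 10, 3, 20, 20, 4
gru = torch.nn.GRU(input_size, hidden_size, num_layers)

x = torch.randn(seq_len, batch, input_size)       # (seq_len, batch, input_size)
h0 = torch.zeros(num_layers, batch, hidden_size)  # (num_layers, batch, hidden_size)
out, hn = gru(x, h0)

# Omitting h_0 makes PyTorch start from zeros, so this matches the call above
out_default, hn_default = gru(x)

# batch_first=True only affects input/output; h_0 stays (num_layers, batch, hidden_size)
gru_bf = torch.nn.GRU(input_size, hidden_size, num_layers, batch_first=True)
x_bf = x.transpose(0, 1).contiguous()             # now (batch, seq_len, input_size)
out_bf, hn_bf = gru_bf(x_bf, h0)

print(out.shape, hn.shape)        # torch.Size([10, 3, 20]) torch.Size([4, 3, 20])
print(out_bf.shape, hn_bf.shape)  # torch.Size([3, 10, 20]) torch.Size([4, 3, 20])
```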

Step 4 - Apply the GRU

output_data, h_n_data = my_gru(input_data, h_0_data)
print("This is the output data:",output_data, "\n")
print("This is hidden state:",h_n_data)
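output_data holds the top layer's hidden state at every time step, while h_n_data holds every layer's hidden state at the final time step only. That is why the last block of the printed hidden state matches the last block of the printed output. A short check, re-running the recipe's code with a fixed seed (the seed is an addition, only for reproducibility):

```python
import torch

torch.manual_seed(0)
my_gru = torch.nn.GRU(20, 20, 4)
input_data = torch.randn(10, 4, 20)
h_0_data = torch.randn(4, 4, 20)
output_data, h_n_data = my_gru(input_data, h_0_data)

print(output_data.shape)  # torch.Size([10, 4, 20]) -> (seq_len, batch, hidden_size)
print(h_n_data.shape)     # torch.Size([4, 4, 20])  -> (num_layers, batch, hidden_size)

# Last layer's final hidden state == top-layer output at the last time step
print(torch.allclose(h_n_data[-1], output_data[-1]))  # True
```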

This is the output data: tensor([[[-3.5257e-01, -6.0325e-01,  3.7681e-01, -3.2584e-01,  5.7271e-01,
           4.0499e-01, -5.2022e-01,  3.5511e-01,  4.2433e-02,  3.2935e-02,
           5.9637e-02,  4.8567e-01, -3.2511e-01, -1.9404e-01, -4.0323e-01,
          -1.9137e-02,  3.2097e-01,  2.0129e-01,  1.8089e-01,  1.0978e-01],
         [-5.9031e-01, -1.8283e-01,  4.1243e-01, -6.0601e-01, -1.6516e+00,
           2.4498e-01, -3.6091e-01, -4.9958e-01,  7.1854e-01, -4.1968e-01,
           2.4215e-01, -1.0387e+00, -4.8143e-01, -2.4346e-01, -1.2174e-01,
          -5.8853e-02,  4.0113e-01,  3.3840e-01,  7.2824e-02, -7.0315e-01],
         [-4.6127e-01, -3.6952e-01, -1.1964e+00, -2.0305e-01,  4.2952e-01,
          -2.9603e-01, -3.9396e-03,  6.5457e-01, -7.4918e-01, -6.7001e-01,
           5.5082e-01, -3.6643e-01,  1.0044e+00, -4.5786e-01, -8.6350e-01,
           7.0952e-01, -5.3744e-01,  2.5078e-01,  1.4593e-01, -2.6031e-01],
         [ 6.6160e-01, -1.0241e+00,  1.0481e-02, -1.0129e+00, -9.5937e-02,
           4.6673e-02,  2.7814e-02, -9.1663e-01,  2.7092e-01,  2.3036e-01,
           9.3286e-01, -9.9219e-01, -9.2160e-01, -2.9476e-01,  1.6657e-02,
           5.3559e-01, -1.3702e-01, -8.7886e-01,  3.9514e-01,  4.3436e-01]],

        [[-1.9122e-01, -9.8273e-02,  1.2223e-01, -1.7883e-01,  5.2069e-01,
          -2.2018e-02, -3.2053e-01,  2.9019e-01, -1.8584e-01, -9.9312e-02,
           1.4876e-02,  2.7383e-01, -2.9829e-01, -2.0966e-01, -2.1641e-01,
          -2.6858e-01,  1.7109e-01,  9.8135e-02,  2.7555e-02, -9.7616e-02],
         [-4.4145e-01,  6.8763e-02,  1.8921e-01, -2.2197e-01, -7.6865e-01,
           1.0955e-01, -3.7669e-01, -2.7032e-01,  3.2424e-01, -5.2213e-01,
           2.0233e-01, -3.9297e-01, -4.1984e-01, -2.5487e-01, -3.0405e-01,
           1.6157e-02,  4.5240e-02,  2.0210e-01, -1.8408e-02, -1.4580e-01],
         [-4.0816e-01, -1.9272e-01, -9.7198e-01, -7.2211e-02,  4.0123e-01,
          -1.6722e-01, -2.5704e-01,  5.6354e-01, -7.1037e-01, -4.9016e-01,
           4.3144e-02, -3.5106e-01,  8.3300e-01, -4.0716e-01, -4.7336e-01,
           3.5242e-01, -4.9447e-01,  3.3389e-01,  7.0814e-02, -2.0190e-01],
         [ 5.2071e-01, -6.5492e-01, -2.6646e-01, -7.1111e-01,  4.2677e-02,
           6.0195e-02,  7.8938e-03, -5.4993e-01,  9.5302e-02,  3.1963e-01,
           5.9517e-01, -4.3954e-01, -8.1247e-01, -1.1960e-01, -1.6329e-01,
           2.9959e-01,  4.6766e-02, -4.6342e-01,  2.3437e-01,  3.3319e-01]],

        [[-1.6752e-01,  1.3031e-01, -4.8660e-02, -1.3289e-01,  4.6559e-01,
          -1.8241e-01, -2.5889e-01,  2.4964e-01, -3.1082e-01, -1.7225e-01,
          -5.2269e-02,  1.8618e-01, -1.7835e-01, -2.3875e-01, -1.4765e-01,
          -3.4576e-01,  5.0683e-02,  7.0277e-02, -5.0934e-02, -7.3643e-02],
         [-4.2515e-01,  1.9330e-01,  1.0895e-02, -1.4034e-01, -2.5120e-01,
           9.6355e-02, -3.3489e-01, -9.6485e-02,  6.2567e-02, -4.4190e-01,
           1.4932e-01, -9.3036e-02, -3.3620e-01, -2.7323e-01, -3.6810e-01,
           1.0660e-01, -5.4985e-03,  4.8381e-02, -1.3594e-01,  1.2797e-01],
         [-3.7478e-01, -2.2461e-02, -8.0193e-01, -5.0541e-02,  3.4762e-01,
          -7.9647e-02, -3.4443e-01,  4.5496e-01, -6.3498e-01, -3.7831e-01,
          -1.4987e-01, -2.5103e-01,  6.7978e-01, -3.5815e-01, -3.1329e-01,
           1.6059e-01, -4.5171e-01,  2.6664e-01,  4.3204e-02, -5.6015e-02],
         [ 3.1383e-01, -3.2084e-01, -3.8000e-01, -5.0376e-01,  1.5643e-01,
           6.2568e-02, -6.4008e-02, -3.0334e-01, -1.5168e-02,  3.0263e-01,
           3.9007e-01, -1.4331e-01, -6.7359e-01, -7.7264e-02, -2.0942e-01,
           1.6851e-01,  1.2699e-01, -2.0729e-01,  8.4433e-02,  2.2303e-01]],

        [[-1.9951e-01,  2.2409e-01, -1.5456e-01, -1.3487e-01,  4.2064e-01,
          -2.0631e-01, -2.6335e-01,  2.1477e-01, -3.4324e-01, -2.0060e-01,
          -1.0170e-01,  1.3098e-01, -5.3913e-02, -2.8088e-01, -1.4762e-01,
          -3.1272e-01, -2.5395e-02,  3.9306e-02, -8.6915e-02,  3.5670e-03],
         [-4.0231e-01,  2.4257e-01, -1.3074e-01, -1.4240e-01,  2.0085e-02,
           9.6720e-02, -3.1516e-01,  1.1341e-02, -1.0017e-01, -3.0291e-01,
           1.0922e-01, -3.2604e-03, -2.5442e-01, -2.8261e-01, -3.8497e-01,
           1.5811e-01, -5.4821e-03, -4.6344e-02, -1.9841e-01,  2.0822e-01],
         [-3.2551e-01,  9.0523e-02, -6.7393e-01, -7.5704e-02,  3.1930e-01,
          -1.7572e-02, -3.6845e-01,  3.3534e-01, -5.4765e-01, -2.7119e-01,
          -1.8212e-01, -1.5886e-01,  5.3409e-01, -3.2419e-01, -2.7555e-01,
           8.8708e-02, -3.9941e-01,  1.2664e-01,  4.4510e-02,  8.6389e-02],
         [ 1.3954e-01, -9.3555e-02, -4.1780e-01, -3.7242e-01,  2.3935e-01,
           5.7587e-02, -1.4371e-01, -1.6685e-01, -8.4725e-02,  2.4596e-01,
           2.5742e-01, -2.0934e-02, -5.2591e-01, -9.3083e-02, -2.2277e-01,
           1.0595e-01,  1.3321e-01, -9.5515e-02, -1.5655e-02,  1.4464e-01]],

        [[-2.2685e-01,  2.4816e-01, -2.1655e-01, -1.5877e-01,  3.8766e-01,
          -1.7526e-01, -2.9837e-01,  1.8130e-01, -3.3772e-01, -2.0187e-01,
          -1.3693e-01,  8.1428e-02,  4.7869e-02, -3.1715e-01, -1.6445e-01,
          -2.4781e-01, -7.4383e-02,  8.9302e-03, -9.4694e-02,  6.8949e-02],
         [-3.6480e-01,  2.4722e-01, -2.2170e-01, -1.5426e-01,  1.5334e-01,
           9.6882e-02, -3.1658e-01,  5.6064e-02, -1.8822e-01, -1.8975e-01,
           6.7152e-02,  9.7271e-03, -1.8717e-01, -2.8540e-01, -3.7849e-01,
           1.6482e-01, -1.5371e-02, -9.8311e-02, -2.0780e-01,  2.1993e-01],
         [-2.8077e-01,  1.5530e-01, -5.7469e-01, -1.1213e-01,  3.1751e-01,
           2.3849e-02, -3.6861e-01,  2.2862e-01, -4.6564e-01, -1.7063e-01,
          -1.6039e-01, -9.9901e-02,  3.9963e-01, -3.1420e-01, -2.8815e-01,
           5.9017e-02, -3.4082e-01, -2.0618e-03,  5.4727e-02,  1.7584e-01],
         [ 1.1074e-02,  4.3741e-02, -4.2191e-01, -2.9566e-01,  2.9465e-01,
           5.1546e-02, -2.0595e-01, -9.0223e-02, -1.2801e-01,  1.8163e-01,
           1.6675e-01,  1.4507e-02, -3.9425e-01, -1.3408e-01, -2.2816e-01,
           7.1458e-02,  1.0304e-01, -6.0851e-02, -6.9813e-02,  1.0447e-01]],

        [[-2.3717e-01,  2.4560e-01, -2.4785e-01, -1.8371e-01,  3.5756e-01,
          -1.3250e-01, -3.3481e-01,  1.5019e-01, -3.1754e-01, -1.8960e-01,
          -1.5631e-01,  4.0424e-02,  1.1359e-01, -3.3659e-01, -1.7638e-01,
          -1.8660e-01, -1.0796e-01, -1.3218e-02, -8.5599e-02,  1.0937e-01],
         [-3.2646e-01,  2.3046e-01, -2.6579e-01, -1.7203e-01,  2.1437e-01,
           9.3096e-02, -3.3264e-01,  6.4429e-02, -2.3362e-01, -1.1887e-01,
           2.5844e-02, -5.9269e-03, -1.3657e-01, -2.7660e-01, -3.5059e-01,
           1.4543e-01, -4.1491e-02, -1.1860e-01, -1.7898e-01,  2.0321e-01],
         [-2.6180e-01,  1.8982e-01, -4.9465e-01, -1.5332e-01,  3.3751e-01,
           4.2177e-02, -3.6131e-01,  1.5174e-01, -3.9683e-01, -9.0574e-02,
          -1.3451e-01, -7.7353e-02,  2.8081e-01, -3.2155e-01, -3.0836e-01,
           3.5021e-02, -2.7743e-01, -8.5117e-02,  6.2857e-02,  2.0566e-01],
         [-7.4263e-02,  1.2216e-01, -4.1295e-01, -2.6124e-01,  3.2838e-01,
           4.0267e-02, -2.5801e-01, -3.8918e-02, -1.6236e-01,  1.2797e-01,
           9.9388e-02,  9.1188e-03, -2.8504e-01, -1.7622e-01, -2.2023e-01,
           4.2337e-02,  6.1741e-02, -5.0961e-02, -8.3473e-02,  7.8634e-02]],

        [[-2.5660e-01,  2.3923e-01, -2.5282e-01, -2.0337e-01,  3.3184e-01,
          -8.9981e-02, -3.5572e-01,  1.2271e-01, -2.8920e-01, -1.8606e-01,
          -1.6362e-01,  2.1673e-03,  1.3392e-01, -3.4578e-01, -1.8833e-01,
          -1.3113e-01, -1.3524e-01, -2.4226e-02, -7.1607e-02,  1.3256e-01],
         [-3.1016e-01,  2.1184e-01, -2.7391e-01, -1.9118e-01,  2.4006e-01,
           8.1777e-02, -3.5206e-01,  6.4537e-02, -2.5456e-01, -9.9274e-02,
          -1.8296e-02, -2.9064e-02, -1.0267e-01, -2.6487e-01, -3.1222e-01,
           1.0917e-01, -7.3661e-02, -1.0781e-01, -1.4101e-01,  1.7699e-01],
         [-2.4569e-01,  2.0318e-01, -4.3312e-01, -1.9013e-01,  3.4607e-01,
           4.6752e-02, -3.6065e-01,  9.9960e-02, -3.3949e-01, -3.2186e-02,
          -1.1848e-01, -7.2162e-02,  1.8711e-01, -3.2674e-01, -3.0393e-01,
           1.8445e-02, -2.3079e-01, -1.2937e-01,  7.1797e-02,  1.9242e-01],
         [-1.3378e-01,  1.6098e-01, -4.0057e-01, -2.5365e-01,  3.5027e-01,
           3.0125e-02, -2.9759e-01, -9.3294e-03, -1.8193e-01,  8.2536e-02,
           4.6250e-02, -1.6784e-02, -2.0397e-01, -2.1583e-01, -2.0633e-01,
           1.8390e-02,  2.2680e-02, -4.5915e-02, -7.6309e-02,  6.5137e-02]],

        [[-2.6803e-01,  2.2797e-01, -2.4938e-01, -2.2410e-01,  3.0997e-01,
          -5.8781e-02, -3.6767e-01,  9.6937e-02, -2.6660e-01, -1.7761e-01,
          -1.6419e-01, -3.4846e-02,  1.2211e-01, -3.4843e-01, -1.8247e-01,
          -9.2400e-02, -1.5474e-01, -3.0171e-02, -4.9615e-02,  1.3579e-01],
         [-2.9673e-01,  2.0612e-01, -2.7944e-01, -2.0105e-01,  2.5192e-01,
           7.1161e-02, -3.5872e-01,  6.0554e-02, -2.6852e-01, -9.4308e-02,
          -4.5881e-02, -4.5210e-02, -8.1419e-02, -2.6007e-01, -2.7438e-01,
           7.7579e-02, -1.0340e-01, -9.7610e-02, -1.0478e-01,  1.5510e-01],
         [-2.3949e-01,  2.0492e-01, -3.8762e-01, -2.1997e-01,  3.4759e-01,
           4.2457e-02, -3.6229e-01,  6.8318e-02, -2.9711e-01, -3.5429e-03,
          -1.1383e-01, -7.7915e-02,  1.2099e-01, -3.2803e-01, -2.8288e-01,
           9.8710e-03, -2.0249e-01, -1.4271e-01,  7.4880e-02,  1.6671e-01],
         [-1.8224e-01,  1.8270e-01, -3.8950e-01, -2.5618e-01,  3.6386e-01,
           2.6625e-02, -3.1981e-01,  5.2382e-03, -1.8719e-01,  4.2516e-02,
           5.4355e-03, -4.8995e-02, -1.5488e-01, -2.5436e-01, -1.9095e-01,
           1.9910e-03, -1.2926e-02, -3.9641e-02, -6.0316e-02,  5.7426e-02]],

        [[-2.8018e-01,  2.2370e-01, -2.3841e-01, -2.4390e-01,  2.8889e-01,
          -4.8420e-02, -3.7180e-01,  8.7240e-02, -2.5285e-01, -1.8025e-01,
          -1.6758e-01, -6.0894e-02,  1.0125e-01, -3.4295e-01, -1.6344e-01,
          -7.9027e-02, -1.6146e-01, -1.6084e-02, -3.2474e-02,  1.2430e-01],
         [-2.8119e-01,  2.0523e-01, -2.9023e-01, -2.0693e-01,  2.7009e-01,
           6.7453e-02, -3.5589e-01,  5.0499e-02, -2.7776e-01, -8.9518e-02,
          -6.0246e-02, -6.1228e-02, -6.9642e-02, -2.6643e-01, -2.4745e-01,
           5.2321e-02, -1.2166e-01, -9.5882e-02, -7.1536e-02,  1.4752e-01],
         [-2.4141e-01,  2.0273e-01, -3.5681e-01, -2.4153e-01,  3.4291e-01,
           4.0804e-02, -3.6406e-01,  4.9810e-02, -2.6988e-01,  8.0396e-03,
          -1.0789e-01, -8.4890e-02,  7.0776e-02, -3.2307e-01, -2.6066e-01,
           1.3517e-02, -1.8527e-01, -1.4421e-01,  7.3444e-02,  1.4385e-01],
         [-2.1676e-01,  1.9428e-01, -3.8135e-01, -2.5811e-01,  3.7041e-01,
           3.2289e-02, -3.2607e-01,  5.9846e-03, -1.8404e-01,  7.5610e-03,
          -2.2389e-02, -7.8454e-02, -1.3064e-01, -2.8681e-01, -1.8134e-01,
          -2.2661e-05, -4.5871e-02, -4.0483e-02, -4.5466e-02,  6.0245e-02]],

        [[-2.9305e-01,  2.2497e-01, -2.2991e-01, -2.6210e-01,  2.7677e-01,
          -4.7456e-02, -3.7531e-01,  9.1893e-02, -2.4947e-01, -1.8620e-01,
          -1.7134e-01, -7.7954e-02,  8.6340e-02, -3.4087e-01, -1.4247e-01,
          -7.7346e-02, -1.5506e-01,  5.6221e-03, -2.1138e-02,  1.0540e-01],
         [-2.6055e-01,  2.0358e-01, -2.9647e-01, -2.0824e-01,  2.8282e-01,
           7.0376e-02, -3.5000e-01,  3.8520e-02, -2.8129e-01, -8.5425e-02,
          -6.5337e-02, -7.0258e-02, -5.9397e-02, -2.6895e-01, -2.3462e-01,
           3.6933e-02, -1.3267e-01, -1.0049e-01, -4.6274e-02,  1.5483e-01],
         [-2.4602e-01,  2.0742e-01, -3.3668e-01, -2.4800e-01,  3.3161e-01,
           4.7077e-02, -3.5865e-01,  3.8686e-02, -2.5637e-01,  7.3562e-03,
          -9.4601e-02, -8.3786e-02,  3.3584e-02, -3.1317e-01, -2.5079e-01,
           3.0364e-02, -1.8101e-01, -1.4502e-01,  6.6671e-02,  1.3347e-01],
         [-2.3997e-01,  2.0217e-01, -3.7549e-01, -2.5650e-01,  3.7264e-01,
           4.5409e-02, -3.1888e-01,  3.3823e-04, -1.7822e-01, -1.8905e-02,
          -3.6141e-02, -1.0132e-01, -1.2148e-01, -3.1265e-01, -1.8281e-01,
           1.1807e-02, -7.2118e-02, -4.8410e-02, -3.4601e-02,  7.1965e-02]]],
       grad_fn=) 

This is hidden state: tensor([[[-1.5578e-01,  6.3665e-02, -4.5194e-01, -4.7867e-02,  8.1945e-02,
          -4.3716e-01,  2.7361e-01, -2.9341e-02, -3.2566e-01, -3.4417e-01,
           2.3009e-01, -4.1757e-01, -2.4193e-01,  3.0681e-02, -2.2705e-02,
           4.3180e-01,  1.4227e-01, -1.4489e-01, -3.9828e-01, -6.1347e-01],
         [ 2.1799e-01, -9.6391e-02,  2.2520e-01, -2.1957e-01, -1.8771e-01,
          -3.7281e-01, -1.0545e-02, -1.8155e-01,  1.7960e-01,  1.4532e-02,
           1.8363e-01, -3.2759e-02, -1.8658e-01,  3.3290e-01, -2.6345e-02,
           4.9141e-02, -2.1326e-02, -1.1830e-01, -8.7123e-02,  4.5344e-01],
         [ 2.7514e-01, -2.9151e-01,  1.8429e-02, -1.7047e-01,  2.0113e-01,
          -2.4307e-01,  4.4335e-01,  2.8906e-01, -3.2847e-01,  3.4700e-01,
           6.2155e-02,  2.1475e-01,  7.3027e-02,  5.9086e-02, -3.1047e-01,
           1.7391e-01,  1.9032e-01,  1.0473e-01, -1.4716e-01,  6.2468e-02],
         [ 2.3388e-01, -5.2020e-02,  2.0951e-01,  2.7294e-01,  4.3423e-01,
          -6.6956e-01,  6.2341e-01,  7.4768e-01, -2.0953e-01, -3.4832e-01,
          -6.5974e-01,  2.2509e-01,  3.5147e-01, -6.2584e-01,  7.4313e-01,
           2.1495e-02,  3.1210e-01, -2.7008e-02, -2.8151e-01, -5.3942e-01]],

        [[ 1.5385e-01,  2.0180e-01,  2.4708e-01, -2.9517e-03, -2.4830e-01,
           2.3556e-01,  7.6064e-02,  2.0932e-02, -2.7560e-01,  1.4629e-01,
          -1.5155e-01,  1.0428e-02,  2.4841e-01, -5.3213e-03, -1.1825e-01,
          -1.7333e-01,  5.4241e-02, -1.6788e-01,  8.2588e-02,  5.5862e-02],
         [ 1.3654e-01,  1.0345e-02, -7.0530e-02, -1.9001e-01, -6.9209e-02,
           1.6353e-01, -1.4318e-01, -7.0181e-02,  1.4882e-01,  1.2294e-02,
           3.8150e-02,  1.5531e-01,  6.9013e-02, -2.3163e-01, -2.3579e-01,
          -4.5414e-02,  2.0644e-01, -2.1782e-01,  1.9232e-01,  6.9886e-02],
         [ 3.2599e-01,  1.7440e-01,  6.8801e-02, -1.6690e-01, -2.1344e-01,
           6.9699e-02, -7.8946e-02, -1.7150e-01,  8.4177e-02, -6.2729e-03,
           8.0536e-02,  2.6824e-01,  9.2076e-02, -1.1835e-01, -1.9803e-01,
           2.4679e-03,  6.1740e-02,  1.0685e-02,  1.6489e-01,  3.0749e-01],
         [ 4.0204e-01,  3.3234e-01,  3.4477e-01,  2.4157e-02, -2.5119e-01,
           1.7676e-01, -1.2959e-01,  2.3164e-01,  3.9294e-01,  4.1907e-01,
           4.6506e-03,  2.2514e-01, -1.7977e-01, -1.2341e-02, -1.6260e-01,
          -1.7690e-01, -2.4845e-01,  1.8074e-01, -6.3862e-02, -6.1592e-02]],

        [[-5.7067e-02, -2.3470e-01,  1.6510e-01, -1.1555e-01,  7.4884e-02,
           9.3658e-02,  2.9469e-01, -7.5452e-02, -2.4308e-03,  1.1402e-01,
           1.4859e-01,  2.1269e-01,  3.7096e-01, -6.5132e-02,  8.0691e-02,
           1.7443e-01, -2.0023e-01,  8.8065e-02,  3.7643e-02,  2.0633e-01],
         [-9.6098e-02,  3.5538e-02,  1.3694e-01, -8.3510e-02,  1.8905e-01,
           1.9881e-01,  1.9440e-01, -6.3709e-02, -1.4722e-01,  2.6855e-02,
           4.8369e-02, -5.7355e-02,  2.7774e-01, -3.0916e-02,  4.7210e-02,
           6.7662e-02, -1.6999e-01, -1.9481e-03, -1.8549e-01,  1.0833e-01],
         [-1.2830e-01, -3.5909e-02,  2.3784e-02, -1.5410e-01,  1.6742e-01,
           2.3925e-01,  2.6808e-01, -7.7711e-03, -8.8175e-02, -3.5569e-02,
           5.8349e-02, -2.9260e-02,  3.3348e-01,  2.7489e-02,  1.6638e-01,
           2.1644e-02, -8.2408e-02,  3.0857e-03, -6.1547e-02,  1.5024e-01],
         [-2.0877e-01, -2.3510e-01,  1.5869e-01, -7.5440e-02,  2.0076e-01,
           2.9316e-01,  3.3874e-01, -4.1265e-02, -7.8175e-02, -7.7619e-02,
           2.3794e-01, -2.0346e-02,  2.5243e-01,  1.9528e-01,  2.2373e-01,
           9.2668e-02, -1.9082e-01, -7.5764e-02, -1.9075e-01,  1.1805e-01]],

        [[-2.9305e-01,  2.2497e-01, -2.2991e-01, -2.6210e-01,  2.7677e-01,
          -4.7456e-02, -3.7531e-01,  9.1893e-02, -2.4947e-01, -1.8620e-01,
          -1.7134e-01, -7.7954e-02,  8.6340e-02, -3.4087e-01, -1.4247e-01,
          -7.7346e-02, -1.5506e-01,  5.6221e-03, -2.1138e-02,  1.0540e-01],
         [-2.6055e-01,  2.0358e-01, -2.9647e-01, -2.0824e-01,  2.8282e-01,
           7.0376e-02, -3.5000e-01,  3.8520e-02, -2.8129e-01, -8.5425e-02,
          -6.5337e-02, -7.0258e-02, -5.9397e-02, -2.6895e-01, -2.3462e-01,
           3.6933e-02, -1.3267e-01, -1.0049e-01, -4.6274e-02,  1.5483e-01],
         [-2.4602e-01,  2.0742e-01, -3.3668e-01, -2.4800e-01,  3.3161e-01,
           4.7077e-02, -3.5865e-01,  3.8686e-02, -2.5637e-01,  7.3562e-03,
          -9.4601e-02, -8.3786e-02,  3.3584e-02, -3.1317e-01, -2.5079e-01,
           3.0364e-02, -1.8101e-01, -1.4502e-01,  6.6671e-02,  1.3347e-01],
         [-2.3997e-01,  2.0217e-01, -3.7549e-01, -2.5650e-01,  3.7264e-01,
           4.5409e-02, -3.1888e-01,  3.3823e-04, -1.7822e-01, -1.8905e-02,
          -3.6141e-02, -1.0132e-01, -1.2148e-01, -3.1265e-01, -1.8281e-01,
           1.1807e-02, -7.2118e-02, -4.8410e-02, -3.4601e-02,  7.1965e-02]]],
       grad_fn=)
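The numbers above will differ on every run, because both the input tensors and the GRU's weights are randomly initialised, and the grad_fn entry appears because autograd tracks the output for training. For repeatable numbers, or when only running the layer for inference, a sketch like the following may help; the seed 42 and the helper name run_recipe are arbitrary.

```python
import torch


def run_recipe(seed):
    # Seed both the weight initialisation and the random input/hidden tensors
    torch.manual_seed(seed)
    my_gru = torch.nn.GRU(20, 20, 4)
    input_data = torch.randn(10, 4, 20)
    h_0_data = torch.randn(4, 4, 20)
    # no_grad: skip building the autograd graph when only running inference
    with torch.no_grad():
        output_data, h_n_data = my_gru(input_data, h_0_data)
    return output_data, h_n_data


out_a, hn_a = run_recipe(42)
out_b, hn_b = run_recipe(42)
print(torch.equal(out_a, out_b), torch.equal(hn_a, hn_b))  # True True on CPU
print(out_a.requires_grad)  # False, so no grad_fn is attached
```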
