.png)
يمكنك اعتبارها مكتبة لـ Python ، مما يساعد في تنفيذ المهام بشكل أسرع ، والحوسبة العلمية ، وتحولات الوظائف ، والتعلم العميق ، والشبكات العصبية ، وغير ذلك الكثير.
حول Google JAX
حزمة الحساب الأساسية في Python هي حزمة NumPy التي تحتوي على جميع الوظائف مثل التجميعات وعمليات المتجهات والجبر الخطي ومعالجات المصفوفة والمصفوفة ذات الأبعاد n والعديد من الوظائف المتقدمة الأخرى.
ماذا لو تمكنا من تسريع العمليات الحسابية التي يتم إجراؤها باستخدام NumPy - خاصةً لمجموعات البيانات الضخمة؟
هل لدينا شيء يمكن أن يعمل بشكل جيد على قدم المساواة مع أنواع مختلفة من المعالجات مثل GPU أو TPU ، دون أي تغييرات في التعليمات البرمجية؟
ماذا لو كان النظام قادرًا على إجراء تحويلات وظيفية قابلة للتكوين تلقائيًا وبكفاءة أكبر؟
Google JAX عبارة عن مكتبة (أو إطار عمل ، كما تقول ويكيبيديا) تقوم بذلك بالضبط وربما أكثر من ذلك بكثير. تم تصميمه لتحسين الأداء وأداء مهام التعلم الآلي (ML) والتعلم العميق بكفاءة. يوفر Google JAX ميزات التحويل التالية التي تجعله فريدًا عن مكتبات ML الأخرى ويساعد في الحساب العلمي المتقدم للتعلم العميق والشبكات العصبية:
- التمايز التلقائي
- التوجيه التلقائي
- موازاة تلقائية
- تجميع في الوقت المناسب (JIT)

تستخدم جميع التحويلات XLA (الجبر الخطي المعجل) لتحسين الأداء والذاكرة. XLA هو محرك مترجم محسن خاص بالمجال ينفذ الجبر الخطي ويسرع نماذج TensorFlow. لا يتطلب استخدام XLA فوق كود Python أي تغييرات مهمة في التعليمات البرمجية!
دعنا نستكشف بالتفصيل كل من هذه الميزات.
ميزات Google JAX
يأتي Google JAX مزودًا بوظائف تحويل مهمة قابلة للإنشاء لتحسين الأداء وأداء مهام التعلم العميق بكفاءة أكبر. على سبيل المثال ، الاشتقاق التلقائي للحصول على تدرج دالة وإيجاد مشتقات من أي ترتيب. وبالمثل ، فإن الموازاة التلقائية و JIT لأداء مهام متعددة بشكل متوازي. هذه التحولات هي مفتاح تطبيقات مثل الروبوتات والألعاب وحتى البحث.
وظيفة التحويل القابلة للتركيب هي وظيفة خالصة تقوم بتحويل مجموعة من البيانات إلى نموذج آخر. يطلق عليها اسم قابلة للتركيب لأنها قائمة بذاتها (أي أن هذه الوظائف ليس لها تبعيات مع بقية البرنامج) وتكون بلا حالة (أي أن نفس المدخلات ستؤدي دائمًا إلى نفس المخرجات).
ص (س) = T: (و (س))
في المعادلة أعلاه ، f (x) هي الوظيفة الأصلية التي يتم تطبيق التحويل عليها. Y (x) هي الوظيفة الناتجة بعد تطبيق التحويل.
على سبيل المثال ، إذا كان لديك دالة تسمى "total_bill_amt" ، وتريد النتيجة كتحويل دالة ، يمكنك ببساطة استخدام التحويل الذي تريده ، دعنا نقول التدرج (gradient):
grad_total_bill = grad (total_bill_amt)
من خلال تحويل الوظائف العددية باستخدام وظائف مثل grad () ، يمكننا بسهولة الحصول على مشتقاتها ذات الترتيب الأعلى ، والتي يمكننا استخدامها على نطاق واسع في خوارزميات تحسين التعلم العميق مثل الانحدار ، مما يجعل الخوارزميات أسرع وأكثر كفاءة. وبالمثل ، باستخدام jit () ، يمكننا تجميع برامج Python في الوقت المناسب (كسول).
# 1. التمايز التلقائي
تستخدم Python وظيفة autograd للتمييز تلقائيًا بين NumPy وكود Python الأصلي. يستخدم JAX إصدارًا معدلًا من autograd (على سبيل المثال ، grad) ويجمع بين XLA (الجبر الخطي المعجل) لإجراء تمايز تلقائي والعثور على مشتقات من أي ترتيب لـ GPU (وحدات معالجة الرسوم) و TPU (وحدات معالجة Tensor).]
ملاحظة سريعة حول TPU و GPU و CPU: تدير وحدة المعالجة المركزية أو وحدة المعالجة المركزية جميع العمليات على الكمبيوتر. GPU هو معالج إضافي يعمل على تحسين قوة الحوسبة وتشغيل العمليات المتطورة. TPU هي وحدة قوية تم تطويرها خصيصًا لأحمال العمل المعقدة والثقيلة مثل الذكاء الاصطناعي وخوارزميات التعلم العميق.
على طول نفس خطوط وظيفة autograd ، والتي يمكن أن تميز من خلال الحلقات ، العودية ، الفروع ، وما إلى ذلك ، يستخدم JAX وظيفة grad () للتدرجات ذات الوضع العكسي (backpropagation). يمكننا أيضًا اشتقاق دالة لأي ترتيب باستخدام grad:
city (city (city (sin θ))) (1.0)
التمايز التلقائي للرتبة العليا
كما ذكرنا سابقًا ، فإن grad مفيدة جدًا في إيجاد المشتقات الجزئية للدالة. يمكننا استخدام مشتق جزئي لحساب نزول التدرج لدالة التكلفة فيما يتعلق بمعلمات الشبكة العصبية في التعلم العميق لتقليل الخسائر.
حساب المشتق الجزئي
افترض أن دالة لها متغيرات متعددة ، x و y و z. إن إيجاد مشتق أحد المتغيرات عن طريق الحفاظ على المتغيرات الأخرى ثابتة يسمى مشتقًا جزئيًا. لنفترض أن لدينا وظيفة ،
و (س ، ص ، ع) = س + 2 ص + ع 2
مثال لإظهار المشتق الجزئي
سيكون المشتق الجزئي لـ x هو f / ∂x ، وهو ما يخبرنا كيف تتغير دالة لمتغير عندما يكون الآخرون ثابتًا. إذا قمنا بهذا يدويًا ، يجب أن نكتب برنامجًا للتفاضل ، ونطبقه على كل متغير ، ثم نحسب نزول التدرج اللوني. قد يصبح هذا أمرًا معقدًا ويستغرق وقتًا طويلاً لمتغيرات متعددة.
يقسم التفاضل التلقائي الوظيفة إلى مجموعة من العمليات الأولية ، مثل + ، - ، * ، / أو الخطيئة ، وجيب التمام ، والظل ، والخسارة ، وما إلى ذلك ، ثم يطبق قاعدة السلسلة لحساب المشتق. يمكننا القيام بذلك في الوضعين الأمامي والخلفي.

هذا ليس هو! تحدث كل هذه الحسابات بسرعة كبيرة (حسنًا ، فكر في مليون عملية حسابية مماثلة لما سبق والوقت الذي قد تستغرقه!). تهتم XLA بالسرعة والأداء.
# 2. الجبر الخطي المعجل
لنأخذ المعادلة السابقة. بدون XLA ، ستأخذ العملية الحسابية ثلاث نواة (أو أكثر) ، حيث ستؤدي كل نواة مهمة أصغر. فمثلا،
Kernel k1 -> x * 2y (الضرب)
k2 -> x * 2y + z (إضافة)
k3 -> التخفيض
إذا تم تنفيذ نفس المهمة بواسطة XLA ، فإن نواة واحدة تتولى جميع العمليات الوسيطة عن طريق دمجها. يتم بث النتائج الوسيطة للعمليات الأولية بدلاً من تخزينها في الذاكرة ، وبالتالي توفير الذاكرة وتعزيز السرعة.
# 3. تجميع في الوقت المناسب
يستخدم JAX داخليًا مترجم XLA لزيادة سرعة التنفيذ. يمكن لـ XLA زيادة سرعة وحدة المعالجة المركزية ووحدة معالجة الرسومات و TPU. كل هذا ممكن باستخدام تنفيذ كود JIT. لاستخدام هذا ، يمكننا استخدام jit عبر الاستيراد:
هناك طريقة أخرى وهي تزيين jit على تعريف الوظيفة:
هذا الرمز أسرع بكثير لأن التحويل سيعيد النسخة المترجمة من الكود إلى المتصل بدلاً من استخدام مترجم Python. هذا مفيد بشكل خاص لمدخلات المتجهات ، مثل المصفوفات والمصفوفات.
وينطبق الشيء نفسه على جميع وظائف بايثون الموجودة أيضًا. على سبيل المثال ، وظائف من حزمة NumPy. في هذه الحالة ، يجب علينا استيراد jax.numpy كـ jnp بدلاً من NumPy:
بمجرد القيام بذلك ، يحل كائن مصفوفة JAX الأساسية المسمى DeviceArray محل مصفوفة NumPy القياسية. DeviceArray كسول - يتم الاحتفاظ بالقيم في المسرع لحين الحاجة إليها. هذا يعني أيضًا أن برنامج JAX لا ينتظر عودة النتائج إلى برنامج الاستدعاء (Python) ، وبالتالي بعد إرسال غير متزامن.
# 4. التوجيه التلقائي (vmap)
في عالم نموذجي للتعلم الآلي ، لدينا مجموعات بيانات تحتوي على مليون أو أكثر من نقاط البيانات. على الأرجح ، سنقوم ببعض العمليات الحسابية أو المعالجات على كل أو معظم نقاط البيانات هذه - وهي مهمة تستهلك الكثير من الوقت والذاكرة! على سبيل المثال ، إذا كنت تريد العثور على مربع كل نقطة من نقاط البيانات في مجموعة البيانات ، فإن أول شيء تفكر فيه هو إنشاء حلقة وأخذ المربع واحدًا تلو الآخر - أرغ!
إذا أنشأنا هذه النقاط كمتجهات ، فيمكننا القيام بكل المربعات دفعة واحدة عن طريق إجراء معالجات المتجهات أو المصفوفة على نقاط البيانات باستخدام NumPy المفضل لدينا. وإذا كان بإمكان برنامجك القيام بذلك تلقائيًا - فهل يمكنك طلب أي شيء آخر؟ هذا بالضبط ما تفعله JAX! يمكنه توجيه جميع نقاط البيانات تلقائيًا حتى تتمكن من إجراء أي عمليات عليها بسهولة - مما يجعل الخوارزميات الخاصة بك أسرع وأكثر كفاءة.
يستخدم JAX وظيفة vmap للتحويل التلقائي. ضع في اعتبارك المصفوفة التالية:
من خلال القيام بما ورد أعلاه فقط ، سيتم تنفيذ الطريقة التربيعية لكل نقطة في المصفوفة. ولكن إذا قمت بما يلي:
سيتم تنفيذ مربع الطريقة مرة واحدة فقط لأن نقاط البيانات يتم توجيهها تلقائيًا باستخدام طريقة vmap قبل تنفيذ الوظيفة ، ويتم دفع التكرار لأسفل إلى المستوى الأولي للعملية - مما يؤدي إلى مضاعفة المصفوفة بدلاً من الضرب القياسي ، مما يعطي أداءً أفضل .
# 5. برمجة SPMD (pmap)
تعد برمجة SPMD - أو S ingle P rogram M ultiple D ata ضرورية في سياقات التعلم العميق - غالبًا ما تقوم بتطبيق نفس الوظائف على مجموعات مختلفة من البيانات الموجودة في العديد من وحدات معالجة الرسومات أو وحدات المعالجة المركزية. لدى JAX وظيفة تسمى المضخة ، والتي تسمح بالبرمجة المتوازية على وحدات معالجة رسومات متعددة أو أي مسرع. مثل JIT ، سيتم تجميع البرامج التي تستخدم pmap بواسطة XLA وتنفيذها في وقت واحد عبر الأنظمة. يعمل هذا التوازي التلقائي لكل من الحسابات الأمامية والعكسية.

يمكننا أيضًا تطبيق تحويلات متعددة دفعة واحدة وبأي ترتيب على أي وظيفة على النحو التالي:
pmap (vmap (jit (غراد (f (x)))))
تحويلات متعددة قابلة للتكوين